From a980e0a0879d81fc914b6a2be9024c7178be5362 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 8 Apr 2025 16:39:33 -0700 Subject: [PATCH 01/56] pending ast --- guidance/_ast.py | 15 +++++++++++ guidance/models/_base/_model.py | 47 +++++++++++++++------------------ 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index f259aa276..c39508974 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -173,6 +173,21 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: return interpreter._role_end(self, **kwargs) +@dataclass +class CaptureStart(ASTNode): + name: str + list_append: bool = False + + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + return interpreter.capture_start(self, **kwargs) + +@dataclass +class CaptureEnd(ASTNode): + name: str + + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + return interpreter.capture_end(self, **kwargs) + @dataclass class ImageBlob(ASTNode): data: str diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 95c347420..9060689c7 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -8,7 +8,6 @@ from typing_extensions import Self -from ..._ast import ASTNode, Function, _parse_tags from ..._ast import ( ASTNode, Function, @@ -18,6 +17,8 @@ LiteralNode, RoleEnd, RoleStart, + CaptureStart, + CaptureEnd, _parse_tags, ) from ...trace import ( @@ -66,7 +67,8 @@ def __init__( ) -> None: self.echo = echo self._interpreter = interpreter - self._active_blocks: dict[Block, int] = {} + self._pending: tuple[ASTNode, ...] = () + self._active_blocks: tuple[Block, ...] = () self.token_count: int = 0 self._parent: Optional["Model"] = None @@ -93,7 +95,8 @@ def _update_trace_node( ) def __add__(self, other: Union[str, Function, ASTNode]) -> Self: - self = self._apply_blocks() + self = self.copy() + self._apply_blocks() if isinstance(other, str): if other == "": return self @@ -101,8 +104,7 @@ def __add__(self, other: Union[str, Function, ASTNode]) -> Self: if isinstance(other, Function): return other(self) if isinstance(other, ASTNode): - self = self._apply_node(other) - self = self._update_open_block_captures() + self._pending += (other,) return self return NotImplemented @@ -154,13 +156,12 @@ def stream(self) -> "ModelStream": """Return a new model stream object that delays execution until it is iterated over.""" return ModelStream(self) - def _apply_blocks(self) -> Self: - self = self.copy() + def _apply_blocks(self) -> None: global_active_blocks = _active_blocks.get() - for block, start_index in list(reversed(self._active_blocks.items())): + new_active_blocks = [] + for block in reversed(self._active_blocks): # Close blocks that are not globally active anymore if block not in global_active_blocks: - self._active_blocks.pop(block) if block.closer is not None: closer = block.closer if isinstance(closer, str): @@ -169,15 +170,19 @@ def _apply_blocks(self) -> Self: raise NotImplementedError( "Stateful block opener/closer functions are not yet supported" ) - self = self._apply_node(closer) - # Update capture regardless of whether or not it's been closed - if block.name is not None: - self = self.set(block.name, str(self)[start_index:]) + self._pending += (closer,) + if block.name is not None: + self._pending += (CaptureEnd(name=block.name),) + else: + # Not closed, so keep it + new_active_blocks.append(block) + new_active_blocks = list(reversed(new_active_blocks)) 
for block in global_active_blocks: # Open blocks that are not yet locally active if block not in self._active_blocks: - # Set start_index to the current length - self._active_blocks[block] = len(self) + new_active_blocks.append(block) + if block.name is not None: + self._pending += (CaptureStart(name=block.name),) if block.opener is not None: opener = block.opener if isinstance(opener, str): @@ -186,22 +191,14 @@ def _apply_blocks(self) -> Self: raise NotImplementedError( "Stateful block opener/closer functions are not yet supported" ) - self = self._apply_node(opener) - return self - - def _update_open_block_captures(self) -> Self: - self = self.copy() - for block, start_index in self._active_blocks.items(): - if block.name is not None: - self = self.set(block.name, str(self)[start_index:]) - return self + self._pending += (opener,) + self._active_blocks = tuple(new_active_blocks) def copy(self) -> Self: obj = object.__new__(self.__class__) obj.__dict__.update(self.__dict__) obj._interpreter = deepcopy(self._interpreter) - obj._active_blocks = {**self._active_blocks} obj._id = _gen_id() obj._parent_id = self._id obj._trace_nodes = set() From 1cd3ab1408dd34916e64dc26a9f27b330e63773f Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 8 Apr 2025 16:49:46 -0700 Subject: [PATCH 02/56] call _run before accessing state --- guidance/models/_base/_model.py | 48 ++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 16 deletions(-) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 9060689c7..14a7f4109 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -12,6 +12,7 @@ ASTNode, Function, GenAudio, + GrammarNode, ImageBlob, ImageUrl, LiteralNode, @@ -69,7 +70,6 @@ def __init__( self._interpreter = interpreter self._pending: tuple[ASTNode, ...] = () self._active_blocks: tuple[Block, ...] = () - self.token_count: int = 0 self._parent: Optional["Model"] = None self._parent_id: Optional[int] = None @@ -108,11 +108,25 @@ def __add__(self, other: Union[str, Function, ASTNode]) -> Self: return self return NotImplemented - def _apply_node(self, node: ASTNode) -> Self: - self = self.copy() - - # Input side of trace handler. - # TODO: StatefulGuidanceInput up in __add__? + def _run(self) -> None: + buffer: Optional[GrammarNode] = None + nodes = self._pending + self._pending = () + for node in nodes: + if isinstance(node, GrammarNode): + if buffer is None: + buffer = node + else: + buffer += node + else: + if buffer is not None: + self._run_node(buffer) + buffer = None + self._run_node(node) + if buffer is not None: + self._run_node(buffer) + + def _run_node(self, node: ASTNode) -> None: if isinstance(node, RoleStart): self._update_trace_node(self._id, self._parent_id, RoleOpenerInput(name=node.role)) elif isinstance(node, RoleEnd): @@ -132,9 +146,6 @@ def _apply_node(self, node: ASTNode) -> Self: self._update_trace_node(self._id, self._parent_id, StatelessGuidanceInput(value=node)) for i, output_attr in enumerate(self._interpreter.run(node)): - if isinstance(output_attr, TextOutput): - # TODO: put this elsewhere (inside state?) - self.token_count += output_attr.token_count if i != 0: # On the first iteration, we already have a fresh trace node # TODO: should be allowed to associate multiple output_attrs with a single input node? 
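[editor's note: illustrative sketch, not part of the patch] The `_run` method added above drains `Model._pending` in batches: consecutive `GrammarNode`s are coalesced with `+` into a single grammar before one interpreter call, while any non-grammar node (role markers, `CaptureStart`/`CaptureEnd`, media blobs) flushes the buffer and runs on its own. A minimal standalone Python sketch of that batching behavior, using hypothetical stand-in classes rather than the real AST types:

    # Stand-ins for GrammarNode and the non-grammar ASTNodes (hypothetical).
    from typing import Union

    class Grammar:
        def __init__(self, text: str) -> None:
            self.text = text

        def __add__(self, other: "Grammar") -> "Grammar":
            # Mirrors GrammarNode composition via `buffer += node`.
            return Grammar(self.text + other.text)

    class Marker:
        def __init__(self, name: str) -> None:
            self.name = name

    def run_batched(nodes: list[Union[Grammar, Marker]]) -> list[str]:
        ran: list[str] = []
        buffer: Union[Grammar, None] = None
        for node in nodes:
            if isinstance(node, Grammar):
                buffer = node if buffer is None else buffer + node
            else:
                if buffer is not None:  # flush before a non-grammar node
                    ran.append(f"grammar:{buffer.text}")
                    buffer = None
                ran.append(f"marker:{node.name}")
        if buffer is not None:  # flush whatever is left at the end
            ran.append(f"grammar:{buffer.text}")
        return ran

    # Adjacent literals coalesce; the capture marker forces a flush:
    assert run_batched(
        [Grammar("Hello "), Grammar("world"), Marker("capture_end"), Grammar("!")]
    ) == ["grammar:Hello world", "marker:capture_end", "grammar:!"]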
@@ -206,8 +217,13 @@ def copy(self) -> Self:
         obj._update_trace_node(obj._id, obj._parent_id, None)
         return obj
 
+    def _get_state(self) -> State:
+        """Get the state of the model."""
+        self._run()
+        return self._interpreter.state
+
     def __str__(self) -> str:
-        return str(self._interpreter.state)
+        return str(self._get_state())
 
     def __len__(self):
         return len(str(self))
@@ -219,7 +235,7 @@ def __setitem__(self, key, value):
 
     def __getitem__(self, key: str) -> Any:
         try:
-            captures = self._interpreter.state.captures[key]
+            captures = self._get_state().captures[key]
         except KeyError:
             raise KeyError(f"Model does not contain the variable '{key}'")
         if isinstance(captures, list):
@@ -228,7 +244,7 @@ def __getitem__(self, key: str) -> Any:
             return captures["value"]
 
     def __contains__(self, key: str) -> bool:
-        return key in self._interpreter.state.captures
+        return key in self._get_state().captures
 
     def get(self, key: str, default: Optional[D] = None) -> Union[str, list[str], None, D]:
         """Return the value of a variable, or a default value if the variable is not present.
@@ -257,9 +273,9 @@ def set(self, key: str, value: Union[str, list[str]]) -> Self:
         """
         self = self.copy()
         if isinstance(value, list):
-            self._interpreter.state.captures[key] = [{"value": v, "log_prob": None} for v in value]
+            self._get_state().captures[key] = [{"value": v, "log_prob": None} for v in value]
         else:
-            self._interpreter.state.captures[key] = {"value": value, "log_prob": None}
+            self._get_state().captures[key] = {"value": value, "log_prob": None}
         return self
 
     def remove(self, key: str) -> Self:
@@ -271,7 +287,7 @@ def remove(self, key: str) -> Self:
             The variable name to remove.
         """
         self = self.copy()
-        self._interpreter.state.captures.pop(key)
+        self._get_state().captures.pop(key)
         return self
 
     def log_prob(
@@ -287,7 +303,7 @@ def log_prob(
             The value to return if the variable is not currently set.
""" try: - captures = self._interpreter.state.captures[key] + captures = self._get_state().captures[key] except KeyError: return default if isinstance(captures, list): From 07de0c3d26b3ff1ee8d4ddd8a444044e1256ae54 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Wed, 9 Apr 2025 17:30:00 -0700 Subject: [PATCH 03/56] poc of async model --- guidance/_ast.py | 47 ++--- guidance/models/_base/__init__.py | 7 +- guidance/models/_base/_interpreter.py | 56 ++--- guidance/models/_base/_model.py | 234 ++++++++++++++------- guidance/models/_openai.py | 282 ++++++++++++++++++++++++-- 5 files changed, 489 insertions(+), 137 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index c39508974..2c605ef55 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -6,7 +6,8 @@ TYPE_CHECKING, Any, Callable, - Iterator, + Iterable, + AsyncIterable, Optional, Sequence, TypeVar, @@ -19,7 +20,7 @@ from .trace import OutputAttr if TYPE_CHECKING: - from .models._base import Interpreter, State + from .models._base import BaseInterpreter, State # to support the embedding of guidance functions inside Python f-strings we use tags with these delimiters tag_start = "{{G|" # start of a call tag @@ -146,11 +147,11 @@ def __radd__(model): S = TypeVar("S", bound="State") - +R = TypeVar("R", bound=Union[Iterable[OutputAttr], AsyncIterable[OutputAttr]]) class ASTNode(ABC): @abstractmethod - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: pass def simplify(self) -> "ASTNode": @@ -161,7 +162,7 @@ def simplify(self) -> "ASTNode": class RoleStart(ASTNode): role: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter._role_start(self, **kwargs) @@ -169,7 +170,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class RoleEnd(ASTNode): role: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter._role_end(self, **kwargs) @@ -178,21 +179,21 @@ class CaptureStart(ASTNode): name: str list_append: bool = False - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.capture_start(self, **kwargs) @dataclass class CaptureEnd(ASTNode): name: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.capture_end(self, **kwargs) @dataclass class ImageBlob(ASTNode): data: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.image_blob(self, **kwargs) @@ -200,7 +201,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class ImageUrl(ASTNode): url: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.image_url(self, **kwargs) @@ -208,7 +209,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class AudioBlob(ASTNode): data: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def 
_run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.audio_blob(self, **kwargs) @@ -216,7 +217,7 @@ class GenAudio(ASTNode): def __init__(self, kwargs: dict[str, Any]): self.kwargs = kwargs - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.gen_audio(self, **kwargs) @@ -330,7 +331,7 @@ class LiteralNode(GrammarNode): def is_null(self) -> bool: return self.value == "" - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.text(self, **kwargs) @@ -338,7 +339,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class RegexNode(GrammarNode): regex: Optional[str] - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.regex(self, **kwargs) @@ -366,7 +367,7 @@ def simplify(self) -> "GrammarNode": def children(self) -> Sequence["GrammarNode"]: return self.alternatives - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.select(self, **kwargs) @@ -389,7 +390,7 @@ def simplify(self) -> "GrammarNode": def children(self) -> Sequence["GrammarNode"]: return self.nodes - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.join(self, **kwargs) @@ -415,7 +416,7 @@ def children(self) -> Sequence["GrammarNode"]: def simplify(self) -> GrammarNode: return RepeatNode(self.node.simplify(), self.min, self.max) - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.repeat(self, **kwargs) @@ -428,7 +429,7 @@ def is_terminal(self) -> bool: # this can be used as part of bigger regexes return True - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.substring(self, **kwargs) @@ -479,7 +480,7 @@ def is_terminal(self) -> bool: def children(self) -> Sequence["GrammarNode"]: return (self.value,) - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.rule(self, **kwargs) @@ -499,7 +500,7 @@ def is_terminal(self) -> bool: # so it should never be terminal. 
return False - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: if self.target is None: raise ValueError("RuleRefNode target not set") return interpreter.rule(self.target) @@ -519,7 +520,7 @@ class SubgrammarNode(BaseSubgrammarNode): body: GrammarNode skip_regex: Optional[str] = None - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.subgrammar(self, **kwargs) @@ -527,7 +528,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class JsonNode(BaseSubgrammarNode): schema: dict[str, Any] - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.json(self, **kwargs) @@ -535,7 +536,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class LarkNode(BaseSubgrammarNode): lark_grammar: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: return interpreter.lark(self, **kwargs) diff --git a/guidance/models/_base/__init__.py b/guidance/models/_base/__init__.py index 0eb8ddfaf..658336b42 100644 --- a/guidance/models/_base/__init__.py +++ b/guidance/models/_base/__init__.py @@ -1,5 +1,5 @@ -from ._interpreter import Interpreter -from ._model import Model +from ._interpreter import Interpreter, AsyncInterpreter, BaseInterpreter +from ._model import AsyncModel, Model from ._state import State __all__ = [ @@ -11,4 +11,7 @@ "ASTNode", "ContentChunk", "MessageChunk", + "AsyncModel", + "AsyncInterpreter", + "BaseInterpreter", ] diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py index e64aeb0a0..e7fdeffe7 100644 --- a/guidance/models/_base/_interpreter.py +++ b/guidance/models/_base/_interpreter.py @@ -1,5 +1,6 @@ import base64 -from typing import Generic, Iterator, TypeVar +from abc import ABC +from typing import Generic, Iterable, AsyncIterable, TypeVar, Union from ..._ast import ( ASTNode, @@ -26,85 +27,90 @@ from ._state import State S = TypeVar("S", bound=State) +R = TypeVar("R", bound=Union[Iterable[OutputAttr], AsyncIterable[OutputAttr]]) - -class Interpreter(Generic[S]): +class BaseInterpreter(Generic[S, R], ABC): def __init__(self, state: S): self.state = state - def run(self, node: ASTNode, **kwargs) -> Iterator[OutputAttr]: - yield from node.simplify()._run(self, **kwargs) + def run(self, node: ASTNode, **kwargs) -> R: + return node.simplify()._run(self, **kwargs) - def _role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + def _role_start(self, node: RoleStart, **kwargs) -> R: if self.state.active_role is not None: raise ValueError( f"Cannot open role {node.role!r}: {self.state.active_role!r} is already open." 
) return self.role_start(node, **kwargs) - def role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + def role_start(self, node: RoleStart, **kwargs) -> R: raise UnsupportedNodeError(interpreter=self, node=node) - def _role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + def _role_end(self, node: RoleEnd, **kwargs) -> R: if self.state.active_role is None: raise ValueError("Cannot close role without active role") if self.state.active_role != node.role: raise ValueError(f"Cannot close role {node.role!r}: {self.state.active_role!r} is open.") return self.role_end(node, **kwargs) - def role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + def role_end(self, node: RoleEnd, **kwargs) -> R: raise UnsupportedNodeError(interpreter=self, node=node) - def text(self, node: LiteralNode, **kwargs) -> Iterator[OutputAttr]: + def text(self, node: LiteralNode, **kwargs) -> R: raise UnsupportedNodeError(interpreter=self, node=node) - def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + def image_blob(self, node: ImageBlob, **kwargs) -> R: raise UnsupportedNodeError(interpreter=self, node=node) - def image_url(self, node: ImageUrl, **kwargs) -> Iterator[OutputAttr]: + def image_url(self, node: ImageUrl, **kwargs) -> R: image_bytes = bytes_from(node.url, allow_local=False) base64_string = base64.b64encode(image_bytes).decode("utf-8") return self.image_blob(ImageBlob(data=base64_string), **kwargs) - def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: + def grammar(self, node: GrammarNode, **kwargs) -> R: raise UnsupportedNodeError(interpreter=self, node=node) - def regex(self, node: RegexNode, **kwargs) -> Iterator[OutputAttr]: + def regex(self, node: RegexNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def select(self, node: SelectNode, **kwargs) -> Iterator[OutputAttr]: + def select(self, node: SelectNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def join(self, node: JoinNode, **kwargs) -> Iterator[OutputAttr]: + def join(self, node: JoinNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def repeat(self, node: RepeatNode, **kwargs) -> Iterator[OutputAttr]: + def repeat(self, node: RepeatNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def substring(self, node: SubstringNode, **kwargs) -> Iterator[OutputAttr]: + def substring(self, node: SubstringNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def rule(self, node: RuleNode, **kwargs) -> Iterator[OutputAttr]: + def rule(self, node: RuleNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def subgrammar(self, node: SubgrammarNode, **kwargs) -> Iterator[OutputAttr]: + def subgrammar(self, node: SubgrammarNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def json(self, node: JsonNode, **kwargs) -> Iterator[OutputAttr]: + def json(self, node: JsonNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def lark(self, node: LarkNode, **kwargs) -> Iterator[OutputAttr]: + def lark(self, node: LarkNode, **kwargs) -> R: return self.grammar(node, **kwargs) - def audio_blob(self, node: AudioBlob, **kwargs) -> Iterator[OutputAttr]: + def audio_blob(self, node: AudioBlob, **kwargs) -> R: raise UnsupportedNodeError(interpreter=self, node=node) - def gen_audio(self, node: GenAudio, **kwargs) -> Iterator[OutputAttr]: + def gen_audio(self, node: GenAudio, **kwargs) -> R: raise UnsupportedNodeError(interpreter=self, node=node) +class Interpreter(BaseInterpreter[S, Iterable[OutputAttr]]): + pass + +class 
AsyncInterpreter(BaseInterpreter[S, AsyncIterable[OutputAttr]]): + pass class UnsupportedNodeError(ValueError): - def __init__(self, interpreter: Interpreter, node: ASTNode): + def __init__(self, interpreter: BaseInterpreter, node: ASTNode): super().__init__(f"{interpreter} does not support {node!r} of type {type(node)}") self.interpreter = interpreter self.node = node diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 14a7f4109..a2b3e7a97 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -4,7 +4,7 @@ import threading from contextvars import ContextVar, copy_context from copy import deepcopy -from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union, Generic from typing_extensions import Self @@ -29,19 +29,18 @@ RoleCloserInput, RoleOpenerInput, StatelessGuidanceInput, - TextOutput, TraceNode, ) from ...trace._trace import AudioInput from ...visual import TraceMessage -from ._interpreter import Interpreter +from ._interpreter import Interpreter, AsyncInterpreter from ._state import State if TYPE_CHECKING: from ...library._block import Block _active_blocks: ContextVar[tuple["Block", ...]] = ContextVar("active_blocks", default=()) -_event_queues: ContextVar[tuple[queue.Queue["Model"], ...]] = ContextVar( +_event_queues: ContextVar[tuple[queue.Queue["ModelABC"], ...]] = ContextVar( "event_queues", default=() ) _id_counter: int = 0 @@ -58,12 +57,12 @@ def _gen_id(): S = TypeVar("S", bound=State) D = TypeVar("D", bound=Any) +I = TypeVar("I", bound=Union[Interpreter[S], AsyncInterpreter[S]]) - -class Model: +class ModelABC(Generic[S, I]): def __init__( self, - interpreter: Interpreter[S], + interpreter: I, echo: bool = True, ) -> None: self.echo = echo @@ -71,7 +70,7 @@ def __init__( self._pending: tuple[ASTNode, ...] = () self._active_blocks: tuple[Block, ...] 
= () - self._parent: Optional["Model"] = None + self._parent: Optional["ModelABC"] = None self._parent_id: Optional[int] = None self._id: int = _gen_id() self._trace_nodes: set[TraceNode] = set() @@ -80,19 +79,20 @@ def __init__( def _update_trace_node( self, identifier: int, parent_id: Optional[int], node_attr: Optional[NodeAttr] = None ) -> None: - from ...registry import get_trace_handler, get_renderer - - trace_handler = get_trace_handler() - trace_node = trace_handler.update_node(identifier, parent_id, node_attr) - self._trace_nodes.add(trace_node) - if self.echo: - get_renderer().update( - TraceMessage( - trace_id=identifier, - parent_trace_id=parent_id, - node_attr=node_attr, - ), - ) + # from ...registry import get_trace_handler, get_renderer + + # trace_handler = get_trace_handler() + # trace_node = trace_handler.update_node(identifier, parent_id, node_attr) + # self._trace_nodes.add(trace_node) + # if self.echo: + # get_renderer().update( + # TraceMessage( + # trace_id=identifier, + # parent_trace_id=parent_id, + # node_attr=node_attr, + # ), + # ) + pass def __add__(self, other: Union[str, Function, ASTNode]) -> Self: self = self.copy() @@ -108,56 +108,6 @@ def __add__(self, other: Union[str, Function, ASTNode]) -> Self: return self return NotImplemented - def _run(self) -> None: - buffer: Optional[GrammarNode] = None - nodes = self._pending - self._pending = () - for node in nodes: - if isinstance(node, GrammarNode): - if buffer is None: - buffer = node - else: - buffer += node - else: - if buffer is not None: - self._run_node(buffer) - buffer = None - self._run_node(node) - if buffer is not None: - self._run_node(buffer) - - def _run_node(self, node: ASTNode) -> None: - if isinstance(node, RoleStart): - self._update_trace_node(self._id, self._parent_id, RoleOpenerInput(name=node.role)) - elif isinstance(node, RoleEnd): - self._update_trace_node(self._id, self._parent_id, RoleCloserInput(name=node.role)) - elif isinstance(node, LiteralNode): - self._update_trace_node(self._id, self._parent_id, LiteralInput(value=node.value)) - elif isinstance(node, ImageBlob): - self._update_trace_node(self._id, self._parent_id, ImageInput(value=node.data)) - elif isinstance(node, ImageUrl): - # TODO -- let's avoid downloading it here - pass - elif isinstance(node, GenAudio): - self._update_trace_node( - self._id, self._parent_id, AudioInput(value="") - ) # TODO -- what goes here? - else: - self._update_trace_node(self._id, self._parent_id, StatelessGuidanceInput(value=node)) - - for i, output_attr in enumerate(self._interpreter.run(node)): - if i != 0: - # On the first iteration, we already have a fresh trace node - # TODO: should be allowed to associate multiple output_attrs with a single input node? - # TODO: put this responsibility on the client in the case that it breaks a single input - # node into multiple input nodes to be handled sequentially? 
- self._parent_id = self._id - self._id = _gen_id() - self._update_trace_node(self._id, self._parent_id, output_attr) - # Stream current model state - self._send_to_event_queue() - return self - def _send_to_event_queue(self) -> None: """For streaming""" for event_queue in _event_queues.get(): @@ -217,6 +167,57 @@ def copy(self) -> Self: obj._update_trace_node(obj._id, obj._parent_id, None) return obj +class Model(ModelABC[S, Interpreter[S]]): + def _run(self) -> None: + buffer: Optional[GrammarNode] = None + nodes = self._pending + self._pending = () + for node in nodes: + if isinstance(node, GrammarNode): + if buffer is None: + buffer = node + else: + buffer += node + else: + if buffer is not None: + self._run_node(buffer) + buffer = None + self._run_node(node) + if buffer is not None: + self._run_node(buffer) + + def _run_node(self, node: ASTNode) -> None: + if isinstance(node, RoleStart): + self._update_trace_node(self._id, self._parent_id, RoleOpenerInput(name=node.role)) + elif isinstance(node, RoleEnd): + self._update_trace_node(self._id, self._parent_id, RoleCloserInput(name=node.role)) + elif isinstance(node, LiteralNode): + self._update_trace_node(self._id, self._parent_id, LiteralInput(value=node.value)) + elif isinstance(node, ImageBlob): + self._update_trace_node(self._id, self._parent_id, ImageInput(value=node.data)) + elif isinstance(node, ImageUrl): + # TODO -- let's avoid downloading it here + pass + elif isinstance(node, GenAudio): + self._update_trace_node( + self._id, self._parent_id, AudioInput(value="") + ) # TODO -- what goes here? + else: + self._update_trace_node(self._id, self._parent_id, StatelessGuidanceInput(value=node)) + + for i, output_attr in enumerate(self._interpreter.run(node)): + if i != 0: + # On the first iteration, we already have a fresh trace node + # TODO: should be allowed to associate multiple output_attrs with a single input node? + # TODO: put this responsibility on the client in the case that it breaks a single input + # node into multiple input nodes to be handled sequentially? 
+ self._parent_id = self._id + self._id = _gen_id() + self._update_trace_node(self._id, self._parent_id, output_attr) + # Stream current model state + self._send_to_event_queue() + return self + def _get_state(self) -> State: """Get the state of the model.""" self._run() @@ -318,6 +319,97 @@ def __getattribute__(self, name): return super().__getattribute__(name) +class AsyncModel(ModelABC[S, AsyncInterpreter[S]]): + async def _run(self) -> None: + buffer: Optional[GrammarNode] = None + nodes = self._pending + self._pending = () + for node in nodes: + if isinstance(node, GrammarNode): + if buffer is None: + buffer = node + else: + buffer += node + else: + if buffer is not None: + await self._run_node(buffer) + buffer = None + await self._run_node(node) + if buffer is not None: + await self._run_node(buffer) + + async def _run_node(self, node: ASTNode) -> None: + if isinstance(node, RoleStart): + self._update_trace_node(self._id, self._parent_id, RoleOpenerInput(name=node.role)) + elif isinstance(node, RoleEnd): + self._update_trace_node(self._id, self._parent_id, RoleCloserInput(name=node.role)) + elif isinstance(node, LiteralNode): + self._update_trace_node(self._id, self._parent_id, LiteralInput(value=node.value)) + elif isinstance(node, ImageBlob): + self._update_trace_node(self._id, self._parent_id, ImageInput(value=node.data)) + elif isinstance(node, ImageUrl): + # TODO -- let's avoid downloading it here + pass + elif isinstance(node, GenAudio): + self._update_trace_node( + self._id, self._parent_id, AudioInput(value="") + ) # TODO -- what goes here? + else: + self._update_trace_node(self._id, self._parent_id, StatelessGuidanceInput(value=node)) + + i = 0 + async for output_attr in self._interpreter.run(node): + if i != 0: + # On the first iteration, we already have a fresh trace node + # TODO: should be allowed to associate multiple output_attrs with a single input node? + # TODO: put this responsibility on the client in the case that it breaks a single input + # node into multiple input nodes to be handled sequentially? + self._parent_id = self._id + self._id = _gen_id() + self._update_trace_node(self._id, self._parent_id, output_attr) + # Stream current model state + self._send_to_event_queue() + i += 1 + return self + + async def _get_state(self) -> State: + """Get the state of the model.""" + await self._run() + return self._interpreter.state + + async def get(self, key: str) -> Any: + try: + captures = (await self._get_state()).captures[key] + except KeyError: + raise KeyError(f"Model does not contain the variable '{key}'") + if isinstance(captures, list): + return [c["value"] for c in captures] + else: + return captures["value"] + + def __getitem__(self, key): + raise TypeError( + "AsyncModel does not support __getitem__. Use the async get() method instead." + ) + + async def to_string(self) -> str: + """Get the string representation of the model.""" + return str(await self._get_state()) + + def __str__(self) -> str: + raise TypeError( + "AsyncModel does not support __str__. Use the async to_string() method instead." + ) + + async def length(self) -> int: + """Get the length of the model.""" + return len(await self.to_string()) + + def __len__(self): + raise TypeError( + "AsyncModel does not support __len__. Use the async length() method instead." 
+ ) + class ModelStream: def __init__( self, diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index 9d5fb7d6d..ac3e79b55 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -1,7 +1,8 @@ import base64 import wave from io import BytesIO -from typing import TYPE_CHECKING, Iterator, Literal, Optional, Union +from typing import TYPE_CHECKING, Iterable, AsyncIterable, Literal, Optional, Union +from copy import deepcopy from pydantic import BaseModel, Discriminator, Field, TypeAdapter from typing_extensions import Annotated, assert_never @@ -21,7 +22,7 @@ from .._utils import bytes_from from ..trace import ImageOutput, OutputAttr, TextOutput from ..trace._trace import AudioOutput -from ._base import Interpreter, Model, State +from ._base import AsyncInterpreter, Interpreter, AsyncModel, Model, State if TYPE_CHECKING: from openai.types.chat import ChatCompletionChunk @@ -166,31 +167,44 @@ def __init__( self.model = model self.client = openai.OpenAI(api_key=api_key, **kwargs) - def run(self, node: ASTNode, **kwargs) -> Iterator[OutputAttr]: + def __deepcopy__(self, memo): + """Custom deepcopy to ensure client is not copied.""" + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "client": + # Don't copy the client + setattr(result, k, v) + else: + setattr(result, k, deepcopy(v, memo)) + return result + + def run(self, node: ASTNode, **kwargs) -> Iterable[OutputAttr]: if not isinstance(node, RoleStart) and self.state.active_role is None: raise ValueError( "OpenAI models require an active role (e.g. use `with assistant(): ...`)" ) return super().run(node, **kwargs) - def role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + def role_start(self, node: RoleStart, **kwargs) -> Iterable[OutputAttr]: self.state.active_role = node.role # TODO: drop this and yield nothing. We need to add this for now as a workaround for the # fact that current vis code assumes that there is actually a role start message yield TextOutput(value=get_role_start(node.role), is_input=True) - def role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + def role_end(self, node: RoleEnd, **kwargs) -> Iterable[OutputAttr]: self.state.messages.append(self.state.get_active_message()) self.state.audio = None self.state.content = [] self.state.active_role = None yield from () - def text(self, node: LiteralNode, **kwargs) -> Iterator[OutputAttr]: + def text(self, node: LiteralNode, **kwargs) -> Iterable[OutputAttr]: self.state.apply_text(node.value) yield TextOutput(value=node.value, is_input=True) - def rule(self, node: RuleNode, **kwargs) -> Iterator[OutputAttr]: + def rule(self, node: RuleNode, **kwargs) -> Iterable[OutputAttr]: if node.stop: raise ValueError("Stop condition not yet supported for OpenAI") if node.suffix: @@ -222,13 +236,13 @@ def rule(self, node: RuleNode, **kwargs) -> Iterator[OutputAttr]: else: yield from chunks - def regex(self, node: RegexNode, **kwargs) -> Iterator[OutputAttr]: + def regex(self, node: RegexNode, **kwargs) -> Iterable[OutputAttr]: if node.regex is not None: raise ValueError("Regex not yet supported for OpenAI") # We're in unconstrained mode now. 
return self._run(**kwargs) - def json(self, node: JsonNode, **kwargs) -> Iterator[OutputAttr]: + def json(self, node: JsonNode, **kwargs) -> Iterable[OutputAttr]: return self._run( response_format={ "type": "json_schema", @@ -241,7 +255,7 @@ def json(self, node: JsonNode, **kwargs) -> Iterator[OutputAttr]: **kwargs, ) - def _run(self, **kwargs) -> Iterator[OutputAttr]: + def _run(self, **kwargs) -> Iterable[OutputAttr]: if self.state.active_role is None: # Should never happen? raise ValueError( @@ -266,8 +280,8 @@ def _run(self, **kwargs) -> Iterator[OutputAttr]: yield from self._handle_stream(chunks) def _handle_stream( - self, chunks: Iterator["ChatCompletionChunk"] - ) -> Iterator[OutputAttr]: + self, chunks: Iterable["ChatCompletionChunk"] + ) -> Iterable[OutputAttr]: audio: Optional[AssistantAudio] = None for chunk in chunks: choice = chunk.choices[0] @@ -337,8 +351,215 @@ def _handle_stream( yield AudioOutput(value=base64.b64encode(wav_bytes).decode(), is_input=False) +class AsyncOpenAIInterpreter(AsyncInterpreter[OpenAIState]): + log_probs: bool = True + + def __init__( + self, + model: str, + api_key: Optional[str] = None, + **kwargs, + ): + try: + import openai + except ImportError: + raise Exception( + "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" + ) + self.state = OpenAIState() + self.model = model + self.client = openai.AsyncOpenAI(api_key=api_key, **kwargs) + + def __deepcopy__(self, memo): + """Custom deepcopy to ensure client is not copied.""" + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "client": + # Don't copy the client + setattr(result, k, v) + else: + setattr(result, k, deepcopy(v, memo)) + return result + + def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: + if not isinstance(node, RoleStart) and self.state.active_role is None: + raise ValueError( + "OpenAI models require an active role (e.g. use `with assistant(): ...`)" + ) + return super().run(node, **kwargs) + + async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.active_role = node.role + # TODO: drop this and yield nothing. 
We need to add this for now as a workaround for the + # fact that current vis code assumes that there is actually a role start message + yield TextOutput(value=get_role_start(node.role), is_input=True) + + async def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.messages.append(self.state.get_active_message()) + self.state.audio = None + self.state.content = [] + self.state.active_role = None + if False: + # I know this is weird, but this is how async generators work + yield + + async def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.apply_text(node.value) + yield TextOutput(value=node.value, is_input=True) + + async def rule(self, node: RuleNode, **kwargs) -> AsyncIterable[OutputAttr]: + if node.stop: + raise ValueError("Stop condition not yet supported for OpenAI") + if node.suffix: + raise ValueError("Suffix not yet supported for OpenAI") + if node.stop_capture: + raise ValueError("Save stop text not yet supported for OpenAI") + + kwargs = kwargs.copy() + if node.temperature: + kwargs["temperature"] = node.temperature + if node.max_tokens: + kwargs["max_tokens"] = node.max_tokens + + chunks = self.run(node.value, **kwargs) + if node.capture: + buffered_text = "" + async for chunk in chunks: + # TODO: this isinstance check is pretty darn fragile. + # ~there must be a better way~ + if isinstance(chunk, TextOutput): + buffered_text += chunk.value + yield chunk + yield self.state.apply_capture( + name=node.capture, + value=buffered_text, + log_prob=1, # TODO + is_append=node.list_append, + ) + else: + async for chunk in chunks: + yield chunk + + def regex(self, node: RegexNode, **kwargs) -> AsyncIterable[OutputAttr]: + if node.regex is not None: + raise ValueError("Regex not yet supported for OpenAI") + # We're in unconstrained mode now. + return self._run(**kwargs) + + def json(self, node: JsonNode, **kwargs) -> AsyncIterable[OutputAttr]: + return self._run( + response_format={ + "type": "json_schema", + "json_schema": { + "name": "json_schema", # TODO? + "schema": node.schema, + "strict": True, + }, + }, + **kwargs, + ) + + async def _run(self, **kwargs) -> AsyncIterable[OutputAttr]: + if self.state.active_role is None: + # Should never happen? + raise ValueError( + "OpenAI models require chat blocks (e.g. use `with assistant(): ...`)" + ) + if self.state.active_role != "assistant": + raise ValueError( + "OpenAI models can only generate as the assistant (i.e. inside of `with assistant(): ...`)" + ) + if self.state.content: + raise ValueError( + f"OpenAI models do not support pre-filled assistant messages: got data {self.state.content}." 
+ ) + + async with await self.client.chat.completions.create( + model=self.model, + messages=TypeAdapter(list[Message]).dump_python(self.state.messages), # type: ignore[arg-type] + logprobs=self.log_probs, + stream=True, + **kwargs, + ) as chunks: + async for output in self._handle_stream(chunks): + yield output + + async def _handle_stream( + self, chunks: AsyncIterable["ChatCompletionChunk"] + ) -> AsyncIterable[OutputAttr]: + audio: Optional[AssistantAudio] = None + async for chunk in chunks: + choice = chunk.choices[0] + delta = choice.delta + if delta.content is not None: + assert audio is None + content = delta.content + if len(content) == 0: + continue + self.state.apply_text(content) + if choice.logprobs is not None: + # TODO: actually get tokens from this and be less lazy + prob = 2.718 ** choice.logprobs.content[0].logprob + else: + prob = float("nan") + yield TextOutput(value=delta.content, is_generated=True, prob=prob) + elif getattr(delta, "audio", None) is not None: + transcript_chunk: Optional[str] = None + if audio is None: + assert delta.audio.get("id") is not None + audio = AssistantAudio( + id=delta.audio["id"], + expires_at=delta.audio.get("expires_at", 0), # ? + transcript=delta.audio.get("transcript", ""), + data=delta.audio.get("data", ""), + ) + transcript_chunk = delta.audio.get("transcript") + else: + assert delta.audio.get("id") is None or delta.audio["id"] == audio.id + if delta.audio.get("data") is not None: + audio.data += delta.audio["data"] + if delta.audio.get("transcript") is not None: + audio.transcript += delta.audio["transcript"] + transcript_chunk = delta.audio["transcript"] + if delta.audio.get("expires_at") is not None: + assert audio.expires_at == 0 + audio.expires_at = delta.audio["expires_at"] + if transcript_chunk is not None: + # Why not give the users some transcript? 
:) + yield TextOutput( + value=delta.audio["transcript"], + is_generated=True, + ) + elif delta.function_call is not None: + raise NotImplementedError("Function calling not yet supported for OpenAI") + elif delta.tool_calls is not None: + raise NotImplementedError("Tool calling not yet supported for OpenAI") + elif delta.refusal is not None: + raise ValueError(f"OpenAI refused the request: {delta.refusal}") + + if choice.finish_reason is not None: + break + + if audio is not None: + assert self.state.audio is None + self.state.audio = audio + # Create an in-memory WAV file + wav_buffer = BytesIO() + with wave.open(wav_buffer, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) # PCM16 = 2 bytes per sample + wav_file.setframerate(22050) # A guess + wav_file.writeframes(base64.b64decode(audio.data)) + + # Get WAV bytes + wav_bytes = wav_buffer.getvalue() + yield AudioOutput(value=base64.b64encode(wav_bytes).decode(), is_input=False) + + class OpenAIImageInterpreter(OpenAIInterpreter): - def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + def image_blob(self, node: ImageBlob, **kwargs) -> Iterable[OutputAttr]: try: import PIL.Image except ImportError: @@ -360,7 +581,7 @@ def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: ) yield ImageOutput(value=node.data, input=True) - def image_url(self, node: ImageUrl, **kwargs) -> Iterator[OutputAttr]: + def image_url(self, node: ImageUrl, **kwargs) -> Iterable[OutputAttr]: self.state.content.append({"type": "image_url", "image_url": {"url": node.url}}) image_bytes = bytes_from(node.url, allow_local=False) base64_string = base64.b64encode(image_bytes).decode("utf-8") @@ -370,7 +591,7 @@ def image_url(self, node: ImageUrl, **kwargs) -> Iterator[OutputAttr]: class OpenAIAudioInterpreter(OpenAIInterpreter): log_probs: bool = False - def audio_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + def audio_blob(self, node: ImageBlob, **kwargs) -> Iterable[OutputAttr]: format = "wav" # TODO: infer from node self.state.content.append( AudioContent( @@ -383,7 +604,7 @@ def audio_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: ) yield AudioOutput(value=node.data, format=format, input=True) - def gen_audio(self, node: GenAudio, **kwargs) -> Iterator[OutputAttr]: + def gen_audio(self, node: GenAudio, **kwargs) -> Iterable[OutputAttr]: yield from self._run( modalities=["text", "audio"], # Has to be both? audio={ @@ -427,3 +648,32 @@ def __init__( super().__init__( interpreter=interpreter_cls(model, api_key=api_key, **kwargs), echo=echo ) + + +class AsyncOpenAI(AsyncModel): + def __init__( + self, + model: str, + echo: bool = True, + api_key: Optional[str] = None, + **kwargs, + ): + """Build a new OpenAI model object that represents a model in a given state. + + Parameters + ---------- + model : str + The name of the OpenAI model to use (e.g. gpt-4o-mini). + echo : bool + If true the final result of creating this model state will be displayed (as HTML in a notebook). + api_key : None or str + The OpenAI API key to use for remote requests, passed directly to the `openai.OpenAI` constructor. + + **kwargs : + All extra keyword arguments are passed directly to the `openai.OpenAI` constructor. 
Commonly used argument + names include `base_url` and `organization` + """ + super().__init__( + interpreter=AsyncOpenAIInterpreter(model, api_key=api_key, **kwargs), + echo=echo, + ) \ No newline at end of file From eed1e59ef2cb09fb5765693976e28eb288eafae6 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 14 Apr 2025 18:47:30 -0700 Subject: [PATCH 04/56] put everything on top of async backend --- guidance/_ast.py | 42 ++--- guidance/models/_base/__init__.py | 8 +- guidance/models/_base/_interpreter.py | 51 +++--- guidance/models/_base/_model.py | 134 +++++--------- guidance/models/_openai.py | 246 +------------------------- 5 files changed, 100 insertions(+), 381 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index 2c605ef55..886f7d334 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -20,7 +20,7 @@ from .trace import OutputAttr if TYPE_CHECKING: - from .models._base import BaseInterpreter, State + from .models._base import Interpreter, State # to support the embedding of guidance functions inside Python f-strings we use tags with these delimiters tag_start = "{{G|" # start of a call tag @@ -151,7 +151,7 @@ def __radd__(model): class ASTNode(ABC): @abstractmethod - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: pass def simplify(self) -> "ASTNode": @@ -162,7 +162,7 @@ def simplify(self) -> "ASTNode": class RoleStart(ASTNode): role: str - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter._role_start(self, **kwargs) @@ -170,7 +170,7 @@ def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: class RoleEnd(ASTNode): role: str - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter._role_end(self, **kwargs) @@ -179,21 +179,21 @@ class CaptureStart(ASTNode): name: str list_append: bool = False - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.capture_start(self, **kwargs) @dataclass class CaptureEnd(ASTNode): name: str - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.capture_end(self, **kwargs) @dataclass class ImageBlob(ASTNode): data: str - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.image_blob(self, **kwargs) @@ -201,7 +201,7 @@ def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: class ImageUrl(ASTNode): url: str - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.image_url(self, **kwargs) @@ -209,7 +209,7 @@ def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: class AudioBlob(ASTNode): data: str - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.audio_blob(self, **kwargs) @@ -217,7 +217,7 @@ class GenAudio(ASTNode): def __init__(self, kwargs: dict[str, Any]): self.kwargs = kwargs - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", 
**kwargs) -> R: return interpreter.gen_audio(self, **kwargs) @@ -331,7 +331,7 @@ class LiteralNode(GrammarNode): def is_null(self) -> bool: return self.value == "" - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.text(self, **kwargs) @@ -339,7 +339,7 @@ def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: class RegexNode(GrammarNode): regex: Optional[str] - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.regex(self, **kwargs) @@ -367,7 +367,7 @@ def simplify(self) -> "GrammarNode": def children(self) -> Sequence["GrammarNode"]: return self.alternatives - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.select(self, **kwargs) @@ -390,7 +390,7 @@ def simplify(self) -> "GrammarNode": def children(self) -> Sequence["GrammarNode"]: return self.nodes - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.join(self, **kwargs) @@ -416,7 +416,7 @@ def children(self) -> Sequence["GrammarNode"]: def simplify(self) -> GrammarNode: return RepeatNode(self.node.simplify(), self.min, self.max) - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.repeat(self, **kwargs) @@ -429,7 +429,7 @@ def is_terminal(self) -> bool: # this can be used as part of bigger regexes return True - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.substring(self, **kwargs) @@ -480,7 +480,7 @@ def is_terminal(self) -> bool: def children(self) -> Sequence["GrammarNode"]: return (self.value,) - def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.rule(self, **kwargs) @@ -500,7 +500,7 @@ def is_terminal(self) -> bool: # so it should never be terminal. 
return False
 
-    def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R:
+    def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R:
         if self.target is None:
             raise ValueError("RuleRefNode target not set")
         return interpreter.rule(self.target)
@@ -520,7 +520,7 @@ class SubgrammarNode(BaseSubgrammarNode):
     body: GrammarNode
     skip_regex: Optional[str] = None
 
-    def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R:
+    def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R:
         return interpreter.subgrammar(self, **kwargs)
 
@@ -528,7 +528,7 @@ def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R:
 class JsonNode(BaseSubgrammarNode):
     schema: dict[str, Any]
 
-    def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R:
+    def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R:
         return interpreter.json(self, **kwargs)
 
@@ -536,7 +536,7 @@ def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R:
 class LarkNode(BaseSubgrammarNode):
     lark_grammar: str
 
-    def _run(self, interpreter: "BaseInterpreter[S, R]", **kwargs) -> R:
+    def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R:
         return interpreter.lark(self, **kwargs)
 
diff --git a/guidance/models/_base/__init__.py b/guidance/models/_base/__init__.py
index 658336b42..2808e0936 100644
--- a/guidance/models/_base/__init__.py
+++ b/guidance/models/_base/__init__.py
@@ -1,5 +1,5 @@
-from ._interpreter import Interpreter, AsyncInterpreter, BaseInterpreter
-from ._model import AsyncModel, Model
+from ._interpreter import Interpreter
+from ._model import Model
 from ._state import State
 
 __all__ = [
@@ -11,7 +11,5 @@
     "ASTNode",
     "ContentChunk",
     "MessageChunk",
-    "AsyncModel",
-    "AsyncInterpreter",
-    "BaseInterpreter",
+    "Interpreter",
 ]
diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py
index e7fdeffe7..d150bef4c 100644
--- a/guidance/models/_base/_interpreter.py
+++ b/guidance/models/_base/_interpreter.py
@@ -27,90 +27,83 @@ from ._state import State
 
 S = TypeVar("S", bound=State)
-R = TypeVar("R", bound=Union[Iterable[OutputAttr], AsyncIterable[OutputAttr]])
 
-class BaseInterpreter(Generic[S, R], ABC):
+class Interpreter(Generic[S], ABC):
     def __init__(self, state: S):
         self.state = state
 
-    def run(self, node: ASTNode, **kwargs) -> R:
+    def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]:
         return node.simplify()._run(self, **kwargs)
 
-    def _role_start(self, node: RoleStart, **kwargs) -> R:
+    def _role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]:
         if self.state.active_role is not None:
             raise ValueError(
                 f"Cannot open role {node.role!r}: {self.state.active_role!r} is already open."
) return self.role_start(node, **kwargs) - def role_start(self, node: RoleStart, **kwargs) -> R: + def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def _role_end(self, node: RoleEnd, **kwargs) -> R: + def _role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is None: raise ValueError("Cannot close role without active role") if self.state.active_role != node.role: raise ValueError(f"Cannot close role {node.role!r}: {self.state.active_role!r} is open.") return self.role_end(node, **kwargs) - def role_end(self, node: RoleEnd, **kwargs) -> R: + def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def text(self, node: LiteralNode, **kwargs) -> R: + def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def image_blob(self, node: ImageBlob, **kwargs) -> R: + def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def image_url(self, node: ImageUrl, **kwargs) -> R: + def image_url(self, node: ImageUrl, **kwargs) -> AsyncIterable[OutputAttr]: image_bytes = bytes_from(node.url, allow_local=False) base64_string = base64.b64encode(image_bytes).decode("utf-8") return self.image_blob(ImageBlob(data=base64_string), **kwargs) - def grammar(self, node: GrammarNode, **kwargs) -> R: + def grammar(self, node: GrammarNode, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def regex(self, node: RegexNode, **kwargs) -> R: + def regex(self, node: RegexNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def select(self, node: SelectNode, **kwargs) -> R: + def select(self, node: SelectNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def join(self, node: JoinNode, **kwargs) -> R: + def join(self, node: JoinNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def repeat(self, node: RepeatNode, **kwargs) -> R: + def repeat(self, node: RepeatNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def substring(self, node: SubstringNode, **kwargs) -> R: + def substring(self, node: SubstringNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def rule(self, node: RuleNode, **kwargs) -> R: + def rule(self, node: RuleNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def subgrammar(self, node: SubgrammarNode, **kwargs) -> R: + def subgrammar(self, node: SubgrammarNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def json(self, node: JsonNode, **kwargs) -> R: + def json(self, node: JsonNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def lark(self, node: LarkNode, **kwargs) -> R: + def lark(self, node: LarkNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def audio_blob(self, node: AudioBlob, **kwargs) -> R: + def audio_blob(self, node: AudioBlob, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def gen_audio(self, node: GenAudio, **kwargs) -> R: + def gen_audio(self, node: GenAudio, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) 
-class Interpreter(BaseInterpreter[S, Iterable[OutputAttr]]): - pass - -class AsyncInterpreter(BaseInterpreter[S, AsyncIterable[OutputAttr]]): - pass - class UnsupportedNodeError(ValueError): - def __init__(self, interpreter: BaseInterpreter, node: ASTNode): + def __init__(self, interpreter: Interpreter, node: ASTNode): super().__init__(f"{interpreter} does not support {node!r} of type {type(node)}") self.interpreter = interpreter self.node = node diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index a2b3e7a97..d5e8ca960 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -2,9 +2,11 @@ import queue import threading +import asyncio from contextvars import ContextVar, copy_context from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union, Generic +import warnings from typing_extensions import Self @@ -32,15 +34,14 @@ TraceNode, ) from ...trace._trace import AudioInput -from ...visual import TraceMessage -from ._interpreter import Interpreter, AsyncInterpreter +from ._interpreter import Interpreter from ._state import State if TYPE_CHECKING: from ...library._block import Block _active_blocks: ContextVar[tuple["Block", ...]] = ContextVar("active_blocks", default=()) -_event_queues: ContextVar[tuple[queue.Queue["ModelABC"], ...]] = ContextVar( +_event_queues: ContextVar[tuple[queue.Queue["Model"], ...]] = ContextVar( "event_queues", default=() ) _id_counter: int = 0 @@ -57,12 +58,11 @@ def _gen_id(): S = TypeVar("S", bound=State) D = TypeVar("D", bound=Any) -I = TypeVar("I", bound=Union[Interpreter[S], AsyncInterpreter[S]]) -class ModelABC(Generic[S, I]): +class Model: def __init__( self, - interpreter: I, + interpreter: Interpreter[S], echo: bool = True, ) -> None: self.echo = echo @@ -70,7 +70,7 @@ def __init__( self._pending: tuple[ASTNode, ...] = () self._active_blocks: tuple[Block, ...] = () - self._parent: Optional["ModelABC"] = None + self._parent: Optional["Model"] = None self._parent_id: Optional[int] = None self._id: int = _gen_id() self._trace_nodes: set[TraceNode] = set() @@ -167,62 +167,6 @@ def copy(self) -> Self: obj._update_trace_node(obj._id, obj._parent_id, None) return obj -class Model(ModelABC[S, Interpreter[S]]): - def _run(self) -> None: - buffer: Optional[GrammarNode] = None - nodes = self._pending - self._pending = () - for node in nodes: - if isinstance(node, GrammarNode): - if buffer is None: - buffer = node - else: - buffer += node - else: - if buffer is not None: - self._run_node(buffer) - buffer = None - self._run_node(node) - if buffer is not None: - self._run_node(buffer) - - def _run_node(self, node: ASTNode) -> None: - if isinstance(node, RoleStart): - self._update_trace_node(self._id, self._parent_id, RoleOpenerInput(name=node.role)) - elif isinstance(node, RoleEnd): - self._update_trace_node(self._id, self._parent_id, RoleCloserInput(name=node.role)) - elif isinstance(node, LiteralNode): - self._update_trace_node(self._id, self._parent_id, LiteralInput(value=node.value)) - elif isinstance(node, ImageBlob): - self._update_trace_node(self._id, self._parent_id, ImageInput(value=node.data)) - elif isinstance(node, ImageUrl): - # TODO -- let's avoid downloading it here - pass - elif isinstance(node, GenAudio): - self._update_trace_node( - self._id, self._parent_id, AudioInput(value="") - ) # TODO -- what goes here? 
- else: - self._update_trace_node(self._id, self._parent_id, StatelessGuidanceInput(value=node)) - - for i, output_attr in enumerate(self._interpreter.run(node)): - if i != 0: - # On the first iteration, we already have a fresh trace node - # TODO: should be allowed to associate multiple output_attrs with a single input node? - # TODO: put this responsibility on the client in the case that it breaks a single input - # node into multiple input nodes to be handled sequentially? - self._parent_id = self._id - self._id = _gen_id() - self._update_trace_node(self._id, self._parent_id, output_attr) - # Stream current model state - self._send_to_event_queue() - return self - - def _get_state(self) -> State: - """Get the state of the model.""" - self._run() - return self._interpreter.state - def __str__(self) -> str: return str(self._get_state()) @@ -318,8 +262,6 @@ def __getattribute__(self, name): return getattr(self._interpreter, "engine") return super().__getattribute__(name) - -class AsyncModel(ModelABC[S, AsyncInterpreter[S]]): async def _run(self) -> None: buffer: Optional[GrammarNode] = None nodes = self._pending @@ -372,14 +314,47 @@ async def _run_node(self, node: ASTNode) -> None: i += 1 return self - async def _get_state(self) -> State: + async def _get_state_async(self) -> State: """Get the state of the model.""" await self._run() return self._interpreter.state - async def get(self, key: str) -> Any: + def _get_state(self) -> State: + """Get the state of the model.""" + coro = self._get_state_async() + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + else: + warnings.warn( + "Synchronous access to model state from an async context is ill-advised...", + stacklevel=2 + ) + # We're already in an async loop, so we have to run the coroutine in a nested event loop. + # TODO: consider raising an exception (sync guidance function called in async context) + result = None + exception = None + event = threading.Event() + def run(): + nonlocal result, exception + try: + result = asyncio.run(coro) + except Exception as ex: + exception = ex + finally: + event.set() + thread = threading.Thread(target=run) + thread.start() + event.wait() + thread.join() + if exception is not None: + raise exception + return result + + async def get_async(self, key: str) -> Any: try: - captures = (await self._get_state()).captures[key] + captures = (await self._get_state_async()).captures[key] except KeyError: raise KeyError(f"Model does not contain the variable '{key}'") if isinstance(captures, list): @@ -387,28 +362,13 @@ async def get(self, key: str) -> Any: else: return captures["value"] - def __getitem__(self, key): - raise TypeError( - "AsyncModel does not support __getitem__. Use the async get() method instead." - ) - - async def to_string(self) -> str: + async def to_string_async(self) -> str: """Get the string representation of the model.""" - return str(await self._get_state()) - - def __str__(self) -> str: - raise TypeError( - "AsyncModel does not support __str__. Use the async to_string() method instead." - ) + return str(await self._get_state_async()) - async def length(self) -> int: + async def length_async(self) -> int: """Get the length of the model.""" - return len(await self.to_string()) - - def __len__(self): - raise TypeError( - "AsyncModel does not support __len__. Use the async length() method instead." 
- ) + return len(await self.to_string_async()) class ModelStream: def __init__( diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index ac3e79b55..4cdaf439b 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -1,7 +1,7 @@ import base64 import wave from io import BytesIO -from typing import TYPE_CHECKING, Iterable, AsyncIterable, Literal, Optional, Union +from typing import TYPE_CHECKING, AsyncIterable, Literal, Optional, Union from copy import deepcopy from pydantic import BaseModel, Discriminator, Field, TypeAdapter @@ -22,7 +22,7 @@ from .._utils import bytes_from from ..trace import ImageOutput, OutputAttr, TextOutput from ..trace._trace import AudioOutput -from ._base import AsyncInterpreter, Interpreter, AsyncModel, Model, State +from ._base import Interpreter, Model, State if TYPE_CHECKING: from openai.types.chat import ChatCompletionChunk @@ -151,209 +151,6 @@ def __str__(self) -> str: class OpenAIInterpreter(Interpreter[OpenAIState]): log_probs: bool = True - def __init__( - self, - model: str, - api_key: Optional[str] = None, - **kwargs, - ): - try: - import openai - except ImportError: - raise Exception( - "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" - ) - self.state = OpenAIState() - self.model = model - self.client = openai.OpenAI(api_key=api_key, **kwargs) - - def __deepcopy__(self, memo): - """Custom deepcopy to ensure client is not copied.""" - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k == "client": - # Don't copy the client - setattr(result, k, v) - else: - setattr(result, k, deepcopy(v, memo)) - return result - - def run(self, node: ASTNode, **kwargs) -> Iterable[OutputAttr]: - if not isinstance(node, RoleStart) and self.state.active_role is None: - raise ValueError( - "OpenAI models require an active role (e.g. use `with assistant(): ...`)" - ) - return super().run(node, **kwargs) - - def role_start(self, node: RoleStart, **kwargs) -> Iterable[OutputAttr]: - self.state.active_role = node.role - # TODO: drop this and yield nothing. We need to add this for now as a workaround for the - # fact that current vis code assumes that there is actually a role start message - yield TextOutput(value=get_role_start(node.role), is_input=True) - - def role_end(self, node: RoleEnd, **kwargs) -> Iterable[OutputAttr]: - self.state.messages.append(self.state.get_active_message()) - self.state.audio = None - self.state.content = [] - self.state.active_role = None - yield from () - - def text(self, node: LiteralNode, **kwargs) -> Iterable[OutputAttr]: - self.state.apply_text(node.value) - yield TextOutput(value=node.value, is_input=True) - - def rule(self, node: RuleNode, **kwargs) -> Iterable[OutputAttr]: - if node.stop: - raise ValueError("Stop condition not yet supported for OpenAI") - if node.suffix: - raise ValueError("Suffix not yet supported for OpenAI") - if node.stop_capture: - raise ValueError("Save stop text not yet supported for OpenAI") - - kwargs = kwargs.copy() - if node.temperature: - kwargs["temperature"] = node.temperature - if node.max_tokens: - kwargs["max_tokens"] = node.max_tokens - - chunks = self.run(node.value, **kwargs) - if node.capture: - buffered_text = "" - for chunk in chunks: - # TODO: this isinstance check is pretty darn fragile. 
- # ~there must be a better way~ - if isinstance(chunk, TextOutput): - buffered_text += chunk.value - yield chunk - yield self.state.apply_capture( - name=node.capture, - value=buffered_text, - log_prob=1, # TODO - is_append=node.list_append, - ) - else: - yield from chunks - - def regex(self, node: RegexNode, **kwargs) -> Iterable[OutputAttr]: - if node.regex is not None: - raise ValueError("Regex not yet supported for OpenAI") - # We're in unconstrained mode now. - return self._run(**kwargs) - - def json(self, node: JsonNode, **kwargs) -> Iterable[OutputAttr]: - return self._run( - response_format={ - "type": "json_schema", - "json_schema": { - "name": "json_schema", # TODO? - "schema": node.schema, - "strict": True, - }, - }, - **kwargs, - ) - - def _run(self, **kwargs) -> Iterable[OutputAttr]: - if self.state.active_role is None: - # Should never happen? - raise ValueError( - "OpenAI models require chat blocks (e.g. use `with assistant(): ...`)" - ) - if self.state.active_role != "assistant": - raise ValueError( - "OpenAI models can only generate as the assistant (i.e. inside of `with assistant(): ...`)" - ) - if self.state.content: - raise ValueError( - f"OpenAI models do not support pre-filled assistant messages: got data {self.state.content}." - ) - - with self.client.chat.completions.create( - model=self.model, - messages=TypeAdapter(list[Message]).dump_python(self.state.messages), # type: ignore[arg-type] - logprobs=self.log_probs, - stream=True, - **kwargs, - ) as chunks: - yield from self._handle_stream(chunks) - - def _handle_stream( - self, chunks: Iterable["ChatCompletionChunk"] - ) -> Iterable[OutputAttr]: - audio: Optional[AssistantAudio] = None - for chunk in chunks: - choice = chunk.choices[0] - delta = choice.delta - if delta.content is not None: - assert audio is None - content = delta.content - if len(content) == 0: - continue - self.state.apply_text(content) - if choice.logprobs is not None: - # TODO: actually get tokens from this and be less lazy - prob = 2.718 ** choice.logprobs.content[0].logprob - else: - prob = float("nan") - yield TextOutput(value=delta.content, is_generated=True, prob=prob) - elif getattr(delta, "audio", None) is not None: - transcript_chunk: Optional[str] = None - if audio is None: - assert delta.audio.get("id") is not None - audio = AssistantAudio( - id=delta.audio["id"], - expires_at=delta.audio.get("expires_at", 0), # ? - transcript=delta.audio.get("transcript", ""), - data=delta.audio.get("data", ""), - ) - transcript_chunk = delta.audio.get("transcript") - else: - assert delta.audio.get("id") is None or delta.audio["id"] == audio.id - if delta.audio.get("data") is not None: - audio.data += delta.audio["data"] - if delta.audio.get("transcript") is not None: - audio.transcript += delta.audio["transcript"] - transcript_chunk = delta.audio["transcript"] - if delta.audio.get("expires_at") is not None: - assert audio.expires_at == 0 - audio.expires_at = delta.audio["expires_at"] - if transcript_chunk is not None: - # Why not give the users some transcript? 
:) - yield TextOutput( - value=delta.audio["transcript"], - is_generated=True, - ) - elif delta.function_call is not None: - raise NotImplementedError("Function calling not yet supported for OpenAI") - elif delta.tool_calls is not None: - raise NotImplementedError("Tool calling not yet supported for OpenAI") - elif delta.refusal is not None: - raise ValueError(f"OpenAI refused the request: {delta.refusal}") - - if choice.finish_reason is not None: - break - - if audio is not None: - assert self.state.audio is None - self.state.audio = audio - # Create an in-memory WAV file - wav_buffer = BytesIO() - with wave.open(wav_buffer, "wb") as wav_file: - wav_file.setnchannels(1) - wav_file.setsampwidth(2) # PCM16 = 2 bytes per sample - wav_file.setframerate(22050) # A guess - wav_file.writeframes(base64.b64decode(audio.data)) - - # Get WAV bytes - wav_bytes = wav_buffer.getvalue() - yield AudioOutput(value=base64.b64encode(wav_bytes).decode(), is_input=False) - - -class AsyncOpenAIInterpreter(AsyncInterpreter[OpenAIState]): - log_probs: bool = True - def __init__( self, model: str, @@ -559,7 +356,7 @@ async def _handle_stream( class OpenAIImageInterpreter(OpenAIInterpreter): - def image_blob(self, node: ImageBlob, **kwargs) -> Iterable[OutputAttr]: + async def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: try: import PIL.Image except ImportError: @@ -581,7 +378,7 @@ def image_blob(self, node: ImageBlob, **kwargs) -> Iterable[OutputAttr]: ) yield ImageOutput(value=node.data, input=True) - def image_url(self, node: ImageUrl, **kwargs) -> Iterable[OutputAttr]: + async def image_url(self, node: ImageUrl, **kwargs) -> AsyncIterable[OutputAttr]: self.state.content.append({"type": "image_url", "image_url": {"url": node.url}}) image_bytes = bytes_from(node.url, allow_local=False) base64_string = base64.b64encode(image_bytes).decode("utf-8") @@ -591,7 +388,7 @@ def image_url(self, node: ImageUrl, **kwargs) -> Iterable[OutputAttr]: class OpenAIAudioInterpreter(OpenAIInterpreter): log_probs: bool = False - def audio_blob(self, node: ImageBlob, **kwargs) -> Iterable[OutputAttr]: + async def audio_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: format = "wav" # TODO: infer from node self.state.content.append( AudioContent( @@ -604,8 +401,8 @@ def audio_blob(self, node: ImageBlob, **kwargs) -> Iterable[OutputAttr]: ) yield AudioOutput(value=node.data, format=format, input=True) - def gen_audio(self, node: GenAudio, **kwargs) -> Iterable[OutputAttr]: - yield from self._run( + async def gen_audio(self, node: GenAudio, **kwargs) -> AsyncIterable[OutputAttr]: + return self._run( modalities=["text", "audio"], # Has to be both? audio={ "voice": node.kwargs.get("voice", "alloy"), @@ -648,32 +445,3 @@ def __init__( super().__init__( interpreter=interpreter_cls(model, api_key=api_key, **kwargs), echo=echo ) - - -class AsyncOpenAI(AsyncModel): - def __init__( - self, - model: str, - echo: bool = True, - api_key: Optional[str] = None, - **kwargs, - ): - """Build a new OpenAI model object that represents a model in a given state. - - Parameters - ---------- - model : str - The name of the OpenAI model to use (e.g. gpt-4o-mini). - echo : bool - If true the final result of creating this model state will be displayed (as HTML in a notebook). - api_key : None or str - The OpenAI API key to use for remote requests, passed directly to the `openai.OpenAI` constructor. - - **kwargs : - All extra keyword arguments are passed directly to the `openai.OpenAI` constructor. 
Commonly used argument - names include `base_url` and `organization` - """ - super().__init__( - interpreter=AsyncOpenAIInterpreter(model, api_key=api_key, **kwargs), - echo=echo, - ) \ No newline at end of file From fbbe7224a19212762ce4c39ccf66778fda2db685 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 15 Apr 2025 11:02:33 -0700 Subject: [PATCH 05/56] make everything way more lazy --- guidance/_ast.py | 47 ++++++++++++-- guidance/models/_base/_interpreter.py | 23 ++++++- guidance/models/_base/_model.py | 90 +++++++++------------------ guidance/models/_openai.py | 8 +-- 4 files changed, 98 insertions(+), 70 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index 886f7d334..477a9ecc9 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -116,13 +116,13 @@ def __call__(self, model): return model def __add__(self, other): - if not isinstance(other, (str, GrammarNode, Function)): + if not isinstance(other, (str, ASTNode, Function)): return NotImplemented if isinstance(other, str): other = _parse_tags(other) - if isinstance(other, GrammarNode) and other.is_null: + if isinstance(other, ASTNode) and other.is_null: return self def __add__(model): @@ -131,13 +131,13 @@ def __add__(model): return Function(__add__, [], {}) def __radd__(self, other): - if not isinstance(other, (str, GrammarNode, Function)): + if not isinstance(other, (str, ASTNode, Function)): return NotImplemented if isinstance(other, str): other = _parse_tags(other) - if isinstance(other, GrammarNode) and other.is_null: + if isinstance(other, ASTNode) and other.is_null: return self def __radd__(model): @@ -157,6 +157,45 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: def simplify(self) -> "ASTNode": return self + @property + def is_null(self) -> bool: + return False + + def __add__(self, other): + if isinstance(other, str): + other = _parse_tags(other) + + if isinstance(other, ASTNode): + return Concatenate((self, other)) + + return NotImplemented + + def __radd__(self, other): + if isinstance(other, str): + other = _parse_tags(other) + + if isinstance(other, ASTNode): + return Concatenate((other, self)) + + return NotImplemented + + @classmethod + def null(cls) -> "ASTNode": + return Concatenate(()) + +@dataclass(frozen=True) +class Concatenate(ASTNode): + nodes: tuple[ASTNode, ...] 
+ + def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + return interpreter.concatenate(self, **kwargs) + + def __iter__(self) -> Iterable[ASTNode]: + for node in self.nodes: + if isinstance(node, Concatenate): + yield from node + else: + yield node @dataclass class RoleStart(ASTNode): diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py index d150bef4c..dbda28d7d 100644 --- a/guidance/models/_base/_interpreter.py +++ b/guidance/models/_base/_interpreter.py @@ -1,10 +1,11 @@ import base64 from abc import ABC -from typing import Generic, Iterable, AsyncIterable, TypeVar, Union +from typing import Generic, Iterable, AsyncIterable, TypeVar, Union, Optional from ..._ast import ( ASTNode, AudioBlob, + Concatenate, GenAudio, GrammarNode, ImageBlob, @@ -35,6 +36,26 @@ def __init__(self, state: S): def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: return node.simplify()._run(self, **kwargs) + async def concatenate(self, node: Concatenate, **kwargs) -> AsyncIterable[OutputAttr]: + buffer: Optional[GrammarNode] = None + for child in node: + assert not isinstance(child, Concatenate) # iter should be flat + if isinstance(child, GrammarNode): + if buffer is None: + buffer = child + else: + buffer = buffer + child + else: + if buffer is not None: + async for attr in self.run(buffer, **kwargs): + yield attr + buffer = None + async for attr in self.run(child, **kwargs): + yield attr + if buffer is not None: + async for attr in self.run(buffer, **kwargs): + yield attr + def _role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is not None: raise ValueError( diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index d5e8ca960..b24355d0f 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -5,13 +5,14 @@ import asyncio from contextvars import ContextVar, copy_context from copy import deepcopy -from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union, Generic +from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union import warnings from typing_extensions import Self from ..._ast import ( ASTNode, + Concatenate, Function, GenAudio, GrammarNode, @@ -67,7 +68,7 @@ def __init__( ) -> None: self.echo = echo self._interpreter = interpreter - self._pending: tuple[ASTNode, ...] = () + self._pending: Union[None, ASTNode, Function] = None self._active_blocks: tuple[Block, ...] 
= ()

         self._parent: Optional["Model"] = None
@@ -94,6 +95,12 @@ def _update_trace_node(
         #     )
         pass

+    def _add_to_pending(self, item: Union[ASTNode, Function]) -> None:
+        if self._pending is None:
+            self._pending = item
+        else:
+            self._pending += item
+
     def __add__(self, other: Union[str, Function, ASTNode]) -> Self:
         self = self.copy()
         self._apply_blocks()
@@ -101,10 +108,8 @@ def __add__(self, other: Union[str, Function, ASTNode]) -> Self:
             if other == "":
                 return self
             other = _parse_tags(other)
-        if isinstance(other, Function):
-            return other(self)
-        if isinstance(other, ASTNode):
-            self._pending += (other,)
+        if isinstance(other, (ASTNode, Function)):
+            self._add_to_pending(other)
             return self
         return NotImplemented

@@ -127,13 +132,9 @@ def _apply_blocks(self) -> None:
                     closer = block.closer
                     if isinstance(closer, str):
                         closer = _parse_tags(closer)
-                    if isinstance(closer, Function):
-                        raise NotImplementedError(
-                            "Stateful block opener/closer functions are not yet supported"
-                        )
-                    self._pending += (closer,)
+                    self._add_to_pending(closer)
                 if block.name is not None:
-                    self._pending += (CaptureEnd(name=block.name),)
+                    self._add_to_pending(CaptureEnd(name=block.name))
             else:
                 # Not closed, so keep it
                 new_active_blocks.append(block)
@@ -143,16 +144,12 @@
             if block not in self._active_blocks:
                 new_active_blocks.append(block)
                 if block.name is not None:
-                    self._pending += (CaptureStart(name=block.name),)
+                    self._add_to_pending(CaptureStart(name=block.name))
                 if block.opener is not None:
                     opener = block.opener
                     if isinstance(opener, str):
                         opener = _parse_tags(opener)
-                    if isinstance(opener, Function):
-                        raise NotImplementedError(
-                            "Stateful block opener/closer functions are not yet supported"
-                        )
-                    self._pending += (opener,)
+                    self._add_to_pending(opener)
         self._active_blocks = tuple(new_active_blocks)

     def copy(self) -> Self:
@@ -263,55 +260,26 @@ def __getattribute__(self, name):
         return super().__getattribute__(name)

     async def _run(self) -> None:
-        buffer: Optional[GrammarNode] = None
-        nodes = self._pending
-        self._pending = ()
-        for node in nodes:
-            if isinstance(node, GrammarNode):
-                if buffer is None:
-                    buffer = node
-                else:
-                    buffer += node
-            else:
-                if buffer is not None:
-                    await self._run_node(buffer)
-                    buffer = None
-                await self._run_node(node)
-        if buffer is not None:
-            await self._run_node(buffer)
+        new_self = self.copy()
+        while isinstance(new_self._pending, Function):
+            func = new_self._pending
+            new_self._pending = None
+            new_self = func(new_self)
+        self.__dict__ = new_self.__dict__  # I guess
+        if self._pending is None:
+            return
+
+        assert isinstance(self._pending, ASTNode)
+        node = self._pending
+        self._pending = None
+        await self._run_node(node)

-    async def _run_node(self, node: ASTNode) -> None:
-        if isinstance(node, RoleStart):
-            self._update_trace_node(self._id, self._parent_id, RoleOpenerInput(name=node.role))
-        elif isinstance(node, RoleEnd):
-            self._update_trace_node(self._id, self._parent_id, RoleCloserInput(name=node.role))
-        elif isinstance(node, LiteralNode):
-            self._update_trace_node(self._id, self._parent_id, LiteralInput(value=node.value))
-        elif isinstance(node, ImageBlob):
-            self._update_trace_node(self._id, self._parent_id, ImageInput(value=node.data))
-        elif isinstance(node, ImageUrl):
-            # TODO -- let's avoid downloading it here
-            pass
-        elif isinstance(node, GenAudio):
-            self._update_trace_node(
-                self._id, self._parent_id, AudioInput(value="")
-            )  # TODO -- what goes here?
- else: - self._update_trace_node(self._id, self._parent_id, StatelessGuidanceInput(value=node)) - i = 0 + async def _run_node(self, node: ASTNode) -> None: async for output_attr in self._interpreter.run(node): - if i != 0: - # On the first iteration, we already have a fresh trace node - # TODO: should be allowed to associate multiple output_attrs with a single input node? - # TODO: put this responsibility on the client in the case that it breaks a single input - # node into multiple input nodes to be handled sequentially? - self._parent_id = self._id - self._id = _gen_id() self._update_trace_node(self._id, self._parent_id, output_attr) # Stream current model state self._send_to_event_queue() - i += 1 return self async def _get_state_async(self) -> State: diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index 4cdaf439b..5ac1aa33f 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -181,10 +181,10 @@ def __deepcopy__(self, memo): return result def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: - if not isinstance(node, RoleStart) and self.state.active_role is None: - raise ValueError( - "OpenAI models require an active role (e.g. use `with assistant(): ...`)" - ) + # if not isinstance(node, RoleStart) and self.state.active_role is None: + # raise ValueError( + # "OpenAI models require an active role (e.g. use `with assistant(): ...`)" + # ) return super().run(node, **kwargs) async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: From 2df1c39015b1e047b3aee5c6f236725622c0e750 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 15 Apr 2025 11:16:32 -0700 Subject: [PATCH 06/56] AsyncFunction --- guidance/_ast.py | 48 +++++++++++++++++++++++++++++++++ guidance/_guidance.py | 7 ++++- guidance/models/_base/_model.py | 9 +++++-- 3 files changed, 61 insertions(+), 3 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index 477a9ecc9..b4d25dfd1 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -145,6 +145,54 @@ def __radd__(model): return Function(__radd__, [], {}) +@dataclass +class AsyncFunction(Tagged): + name: str = field(init=False) + f: Callable + args: tuple[Any, ...] + kwargs: dict[str, Any] + + def __post_init__(self): + self.name = self.f.__name__ + + async def __call__(self, model): + model = await self.f(model, *self.args, **self.kwargs) + if model is None: + raise Exception( + f"The guidance function `{self.f.__name__}` did not return a model object! You need to return an updated model object at the end of your guidance function." 
+ ) + return model + + def __add__(self, other): + if not isinstance(other, (str, ASTNode, Function, AsyncFunction)): + return NotImplemented + + if isinstance(other, str): + other = _parse_tags(other) + + if isinstance(other, ASTNode) and other.is_null: + return self + + async def __add__(model): + return (await self(model)) + other + + return AsyncFunction(__add__, [], {}) + + def __radd__(self, other): + if not isinstance(other, (str, ASTNode, Function, AsyncFunction)): + return NotImplemented + + if isinstance(other, str): + other = _parse_tags(other) + + if isinstance(other, ASTNode) and other.is_null: + return self + + async def __radd__(model): + return await self(model + other) + + return AsyncFunction(__radd__, [], {}) + S = TypeVar("S", bound="State") R = TypeVar("R", bound=Union[Iterable[OutputAttr], AsyncIterable[OutputAttr]]) diff --git a/guidance/_guidance.py b/guidance/_guidance.py index e58d1699b..6155e88b2 100644 --- a/guidance/_guidance.py +++ b/guidance/_guidance.py @@ -8,7 +8,7 @@ from ._grammar import string -from ._ast import Function, RuleRefNode, RuleNode +from ._ast import AsyncFunction, Function, RuleRefNode, RuleNode from ._utils import strip_multiline_string_indents, make_weak_bound_method, signature_pop from .models import Model @@ -127,6 +127,11 @@ def _decorator(f, *, stateless, cache, model): @functools.wraps(f) def wrapped(*args, **kwargs): + from inspect import iscoroutinefunction + if iscoroutinefunction(f): + if stateless is True: + raise ValueError("Stateless functions cannot be async") + return AsyncFunction(f, args, kwargs) # make a stateless grammar if we can if stateless is True or ( diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index b24355d0f..d15d9677e 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -14,6 +14,7 @@ ASTNode, Concatenate, Function, + AsyncFunction, GenAudio, GrammarNode, ImageBlob, @@ -101,14 +102,14 @@ def _add_to_pending(self, item: Union[ASTNode, Function]) -> None: else: self._pending += item - def __add__(self, other: Union[str, Function, ASTNode]) -> Self: + def __add__(self, other: Union[str, Function, AsyncFunction, ASTNode]) -> Self: self = self.copy() self._apply_blocks() if isinstance(other, str): if other == "": return self other = _parse_tags(other) - if isinstance(other, (ASTNode, Function)): + if isinstance(other, (ASTNode, Function, AsyncFunction)): self._add_to_pending(other) return self return NotImplemented @@ -261,6 +262,10 @@ def __getattribute__(self, name): async def _run(self) -> None: new_self = self.copy() + while isinstance(new_self._pending, AsyncFunction): + func = new_self._pending + new_self._pending = None + new_self = await func(new_self) while isinstance(new_self._pending, Function): func = new_self._pending new_self._pending = None From 55ef91ea3d772e287fb951b36fe74f1d7d7aaa81 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 15 Apr 2025 11:48:34 -0700 Subject: [PATCH 07/56] subtle -- blocks need to be applied before running --- guidance/models/_base/_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index d15d9677e..013516d3b 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -262,6 +262,8 @@ def __getattribute__(self, name): async def _run(self) -> None: new_self = self.copy() + # may be some pending blocks + new_self._apply_blocks() while isinstance(new_self._pending, AsyncFunction): func = new_self._pending 
new_self._pending = None From 636e2baac57ae58f5599538a82dc9bc857ab776e Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 15 Apr 2025 12:53:38 -0700 Subject: [PATCH 08/56] make sure that async + sync function application doesn't run sync inside of async --- guidance/models/_base/_model.py | 93 ++++++++++++++++++++++----------- 1 file changed, 63 insertions(+), 30 deletions(-) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 013516d3b..a345bdc26 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -281,6 +281,33 @@ async def _run(self) -> None: self._pending = None await self._run_node(node) + def _run_sync(self) -> None: + new_self = self.copy() + # may be some pending blocks + new_self._apply_blocks() + if isinstance(new_self._pending, AsyncFunction): + async def async_part(): + nonlocal new_self + while isinstance(new_self._pending, AsyncFunction): + func = new_self._pending + new_self._pending = None + new_self = await func(new_self) + + # TODO: replace with our "safe" version + asyncio.run(async_part()) + while isinstance(new_self._pending, Function): + func = new_self._pending + new_self._pending = None + new_self = func(new_self) + self.__dict__ = new_self.__dict__ # I guess + if self._pending is None: + return + assert isinstance(self._pending, ASTNode) + node = self._pending + self._pending = None + + # TODO: replace with our "safe" version + return asyncio.run(self._run_node(node)) async def _run_node(self, node: ASTNode) -> None: async for output_attr in self._interpreter.run(node): @@ -296,36 +323,42 @@ async def _get_state_async(self) -> State: def _get_state(self) -> State: """Get the state of the model.""" - coro = self._get_state_async() - try: - asyncio.get_running_loop() - except RuntimeError: - return asyncio.run(coro) - else: - warnings.warn( - "Synchronous access to model state from an async context is ill-advised...", - stacklevel=2 - ) - # We're already in an async loop, so we have to run the coroutine in a nested event loop. - # TODO: consider raising an exception (sync guidance function called in async context) - result = None - exception = None - event = threading.Event() - def run(): - nonlocal result, exception - try: - result = asyncio.run(coro) - except Exception as ex: - exception = ex - finally: - event.set() - thread = threading.Thread(target=run) - thread.start() - event.wait() - thread.join() - if exception is not None: - raise exception - return result + self._run_sync() + return self._interpreter.state + + # coro = self._get_state_async() + # try: + # asyncio.get_running_loop() + # except RuntimeError: + # return asyncio.run(coro) + # else: + # raise RuntimeError( + # "Synchronous access to model state from an async context is punishable by death." + # ) + # warnings.warn( + # "Synchronous access to model state from an async context is ill-advised...", + # stacklevel=2 + # ) + # # We're already in an async loop, so we have to run the coroutine in a nested event loop. 
+ # # TODO: consider raising an exception (sync guidance function called in async context) + # result = None + # exception = None + # event = threading.Event() + # def run(): + # nonlocal result, exception + # try: + # result = asyncio.run(coro) + # except Exception as ex: + # exception = ex + # finally: + # event.set() + # thread = threading.Thread(target=run) + # thread.start() + # event.wait() + # thread.join() + # if exception is not None: + # raise exception + # return result async def get_async(self, key: str) -> Any: try: From 032ef61b6d231416a4e607a4f6c7b4bec18dd033 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 15 Apr 2025 14:39:34 -0700 Subject: [PATCH 09/56] centralize run_async_maybe_in_thread logic --- guidance/models/_base/_model.py | 82 +++++++++++++++++---------------- 1 file changed, 42 insertions(+), 40 deletions(-) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index a345bdc26..c230e620f 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -292,9 +292,7 @@ async def async_part(): func = new_self._pending new_self._pending = None new_self = await func(new_self) - - # TODO: replace with our "safe" version - asyncio.run(async_part()) + run_async_maybe_in_thread(async_part()) while isinstance(new_self._pending, Function): func = new_self._pending new_self._pending = None @@ -305,9 +303,7 @@ async def async_part(): assert isinstance(self._pending, ASTNode) node = self._pending self._pending = None - - # TODO: replace with our "safe" version - return asyncio.run(self._run_node(node)) + run_async_maybe_in_thread(self._run_node(node)) async def _run_node(self, node: ASTNode) -> None: async for output_attr in self._interpreter.run(node): @@ -326,40 +322,6 @@ def _get_state(self) -> State: self._run_sync() return self._interpreter.state - # coro = self._get_state_async() - # try: - # asyncio.get_running_loop() - # except RuntimeError: - # return asyncio.run(coro) - # else: - # raise RuntimeError( - # "Synchronous access to model state from an async context is punishable by death." - # ) - # warnings.warn( - # "Synchronous access to model state from an async context is ill-advised...", - # stacklevel=2 - # ) - # # We're already in an async loop, so we have to run the coroutine in a nested event loop. - # # TODO: consider raising an exception (sync guidance function called in async context) - # result = None - # exception = None - # event = threading.Event() - # def run(): - # nonlocal result, exception - # try: - # result = asyncio.run(coro) - # except Exception as ex: - # exception = ex - # finally: - # event.set() - # thread = threading.Thread(target=run) - # thread.start() - # event.wait() - # thread.join() - # if exception is not None: - # raise exception - # return result - async def get_async(self, key: str) -> Any: try: captures = (await self._get_state_async()).captures[key] @@ -449,3 +411,43 @@ def target(ctx): # Reset the event queues context variable _event_queues.reset(token) + +def run_async_maybe_in_thread( + coro +): + """ + Run a coroutine in a thread if not already in an async context. 
+ """ + try: + asyncio.get_running_loop() + except RuntimeError: + # Not in an async context, run the coroutine in a thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + else: + warnings.warn( + "Synchronous access to model state from an async context is ill-advised...", + stacklevel=2 + ) + # We're already in an async loop, so we have to run the coroutine in a nested event loop. + # TODO: consider raising an exception (sync guidance function called in async context) + # TODO: consider using some global thread and call asyncio.run_coroutine_threadsafe + result = None + exception = None + event = threading.Event() + def run(): + nonlocal result, exception + try: + result = asyncio.run(coro) + except Exception as ex: + exception = ex + finally: + event.set() + thread = threading.Thread(target=run) + thread.start() + event.wait() + thread.join() + if exception is not None: + raise exception + return result From 3f3ac5d5ef1be8e9a3d236103f981bf6440fec0d Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 15 Apr 2025 15:02:28 -0700 Subject: [PATCH 10/56] make engine interpreter async (well, as close as we can get right now) --- guidance/models/_base/_interpreter.py | 1 + guidance/models/_engine/_interpreter.py | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py index dbda28d7d..136878313 100644 --- a/guidance/models/_base/_interpreter.py +++ b/guidance/models/_base/_interpreter.py @@ -83,6 +83,7 @@ def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) def image_url(self, node: ImageUrl, **kwargs) -> AsyncIterable[OutputAttr]: + # TODO: we should be using something like httpx to fetch the image image_bytes = bytes_from(node.url, allow_local=False) base64_string = base64.b64encode(image_bytes).decode("utf-8") return self.image_blob(ImageBlob(data=base64_string), **kwargs) diff --git a/guidance/models/_engine/_interpreter.py b/guidance/models/_engine/_interpreter.py index 8fdf599dd..cc95ff788 100644 --- a/guidance/models/_engine/_interpreter.py +++ b/guidance/models/_engine/_interpreter.py @@ -1,6 +1,6 @@ from base64 import b64decode from io import BytesIO -from typing import Iterator +from typing import AsyncIterable from copy import deepcopy from ..._ast import GrammarNode, ImageBlob, LiteralNode, RoleEnd, RoleStart @@ -39,21 +39,21 @@ def get_role_end(self, role: str) -> str: raise ValueError("Cannot use roles without a chat template") return self.chat_template.get_role_end(role) - def role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = node.role # TODO: mark these as special tokens..? - yield from self.run(LiteralNode(value=self.get_role_start(node.role)), **kwargs) + return self.run(LiteralNode(value=self.get_role_start(node.role)), **kwargs) - def role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + async def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = None # TODO: mark these as special tokens..? 
- yield from self.run(LiteralNode(value=self.get_role_end(node.role)), **kwargs) + return self.run(LiteralNode(value=self.get_role_end(node.role)), **kwargs) - def text(self, node: LiteralNode, **kwargs) -> Iterator[OutputAttr]: + async def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: self.state.prompt += node.value yield TextOutput(value=node.value, is_input=True) - def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: + async def grammar(self, node: GrammarNode, **kwargs) -> AsyncIterable[OutputAttr]: engine_gen = self.engine( state=self.state, grammar=node.ll_grammar(), @@ -62,6 +62,7 @@ def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: ) delayed_bytes = b"" + # TODO: this should be async some day for chunk in engine_gen: new_bytes = chunk.new_bytes new_text, delayed_bytes = partial_decode(new_bytes) @@ -112,7 +113,7 @@ def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: class Llama3VisionInterpreter(EngineInterpreter): - def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + async def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: try: import PIL.Image except ImportError: @@ -129,7 +130,7 @@ def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: class Phi3VisionInterpreter(EngineInterpreter): - def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + async def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: try: import PIL.Image except ImportError: From d7d46ee9a2821319fcb5bc0969c32be40053c702 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 15 Apr 2025 15:04:40 -0700 Subject: [PATCH 11/56] fix engine interpreter role start/end --- guidance/models/_engine/_interpreter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/guidance/models/_engine/_interpreter.py b/guidance/models/_engine/_interpreter.py index cc95ff788..b27563b25 100644 --- a/guidance/models/_engine/_interpreter.py +++ b/guidance/models/_engine/_interpreter.py @@ -39,12 +39,12 @@ def get_role_end(self, role: str) -> str: raise ValueError("Cannot use roles without a chat template") return self.chat_template.get_role_end(role) - async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: + def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = node.role # TODO: mark these as special tokens..? return self.run(LiteralNode(value=self.get_role_start(node.role)), **kwargs) - async def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: + def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = None # TODO: mark these as special tokens..? return self.run(LiteralNode(value=self.get_role_end(node.role)), **kwargs) From d17110689e1316dd46d4e40881c873e64227778e Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 15 Apr 2025 15:07:49 -0700 Subject: [PATCH 12/56] change test_text_closer --- tests/unit/library/test_block.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/library/test_block.py b/tests/unit/library/test_block.py index 0726ef192..c86120a7c 100644 --- a/tests/unit/library/test_block.py +++ b/tests/unit/library/test_block.py @@ -10,12 +10,12 @@ def test_text_opener(): def test_text_closer(): - # NOTE(nopdive): Behavioral change, no longer need closer for str call. 
model = models.Mock("a") model += "" with block(closer="close text"): model += regex(r".") - assert str(model) == "a" + assert str(model) == "a" + assert str(model) == "aclose text" def test_grammar_opener(): From dea8020585062076aef10dfed5fcc8926898d8d6 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Wed, 16 Apr 2025 15:35:31 -0700 Subject: [PATCH 13/56] add some tests --- tests/unit/test_async.py | 111 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 tests/unit/test_async.py diff --git a/tests/unit/test_async.py b/tests/unit/test_async.py new file mode 100644 index 000000000..017c479fc --- /dev/null +++ b/tests/unit/test_async.py @@ -0,0 +1,111 @@ +import asyncio +import pytest +from guidance import * + +@guidance +def sync_func(lm: models.Model): + lm += select(["a", "b"], name="choice") + if lm.get("choice") == "a": + lm += "lpha" + else: + lm += "eta" + return lm + +@guidance +async def async_func(lm: models.Model): + lm += select(["a", "b"], name="choice") + if (await lm.get_async("choice")) == "a": + lm += "lpha" + else: + lm += "eta" + return lm + +@guidance +async def async_func_with_sync_accessor(lm: models.Model): + lm += select(["a", "b"], name="choice") + if lm.get("choice") == "a": + lm += "lpha" + else: + lm += "eta" + return lm + +@guidance +def sync_func_calling_async_func(lm: models.Model): + lm += async_func() + lm += select(["a", "b"], name="choice") + if lm.get("choice") == "a": + lm += "lpha" + else: + lm += "eta" + return lm + +@guidance +async def async_func_calling_sync_func(lm: models.Model): + lm += sync_func() + lm += select(["a", "b"], name="choice") + if (await lm.get_async("choice")) == "a": + lm += "lpha" + else: + lm += "eta" + return lm + +@guidance +def sync_then_async(lm: models.Model): + lm += sync_func() + lm += async_func() + return lm + +@guidance +def async_then_sync(lm: models.Model): + lm += async_func() + lm += sync_func() + return lm + +def run(gfunc, sync: bool) -> str: + lm = models.Mock() + lm += gfunc() + if sync: + s = str(lm) + else: + s = asyncio.run(lm.to_string_async()) + return s + +@pytest.mark.parametrize( + "sync", + [ + True, + False, + ], +) +@pytest.mark.parametrize( + "gfunc, expected", + [ + (sync_func, {"alpha", "beta"}), + (async_func, {"alpha", "beta"}), + (sync_then_async, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), + (async_then_sync, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), + (async_func_calling_sync_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), + (sync_func_calling_async_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), + ], +) +def test_async(gfunc, expected, sync): + s = run(gfunc, sync) + assert s in expected + +@pytest.mark.parametrize( + "sync", + [ + True, + False, + ], +) +def test_async_with_sync_accessor(sync): + s = run(async_func_with_sync_accessor, sync) + assert s in {"alphabeta", "betaalpha", "alphaalpha", "betabeta"} + +def test_sync_accessor_in_foreign_event_loop(): + async def main(): + lm = models.Mock() + lm += sync_func() + assert str(lm) in {"alpha", "beta"} + asyncio.run(main()) From 7b3aae66fc01a57a969acf3877e0a9acae6eb6b0 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Wed, 16 Apr 2025 15:39:20 -0700 Subject: [PATCH 14/56] fix sync/async eval loop (recall -- async can be inside of sync) --- guidance/models/_base/_model.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 
c230e620f..f699028a2 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -264,14 +264,14 @@ async def _run(self) -> None: new_self = self.copy() # may be some pending blocks new_self._apply_blocks() - while isinstance(new_self._pending, AsyncFunction): + while isinstance(new_self._pending, (Function, AsyncFunction)): func = new_self._pending new_self._pending = None - new_self = await func(new_self) - while isinstance(new_self._pending, Function): - func = new_self._pending - new_self._pending = None - new_self = func(new_self) + if isinstance(func, AsyncFunction): + new_self = await func(new_self) + else: + # TODO: maybe run in thread to avoid blocking? + new_self = func(new_self) self.__dict__ = new_self.__dict__ # I guess if self._pending is None: return @@ -285,18 +285,14 @@ def _run_sync(self) -> None: new_self = self.copy() # may be some pending blocks new_self._apply_blocks() - if isinstance(new_self._pending, AsyncFunction): - async def async_part(): - nonlocal new_self - while isinstance(new_self._pending, AsyncFunction): - func = new_self._pending - new_self._pending = None - new_self = await func(new_self) - run_async_maybe_in_thread(async_part()) - while isinstance(new_self._pending, Function): + while isinstance(new_self._pending, (Function, AsyncFunction)): func = new_self._pending new_self._pending = None - new_self = func(new_self) + if isinstance(func, AsyncFunction): + # TODO: share a bg thread + new_self = run_async_maybe_in_thread(func(new_self)) + else: + new_self = func(new_self) self.__dict__ = new_self.__dict__ # I guess if self._pending is None: return From 7db5363ba6257d489de72de21487ba9c858f5d10 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Wed, 16 Apr 2025 17:00:06 -0700 Subject: [PATCH 15/56] fix test --- tests/unit/test_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_async.py b/tests/unit/test_async.py index 017c479fc..22313814d 100644 --- a/tests/unit/test_async.py +++ b/tests/unit/test_async.py @@ -101,7 +101,7 @@ def test_async(gfunc, expected, sync): ) def test_async_with_sync_accessor(sync): s = run(async_func_with_sync_accessor, sync) - assert s in {"alphabeta", "betaalpha", "alphaalpha", "betabeta"} + assert s in {"alpha", "beta"} def test_sync_accessor_in_foreign_event_loop(): async def main(): From bb931f392ce88bb1be6e62f2707edab5e3380a5a Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Wed, 16 Apr 2025 11:00:13 -0700 Subject: [PATCH 16/56] make our async re-entrant with greenlets --- guidance/_bridge.py | 101 ++++++++++++++++++++++++++++++++ guidance/models/_base/_model.py | 67 +++------------------ 2 files changed, 108 insertions(+), 60 deletions(-) create mode 100644 guidance/_bridge.py diff --git a/guidance/_bridge.py b/guidance/_bridge.py new file mode 100644 index 000000000..e41670cb5 --- /dev/null +++ b/guidance/_bridge.py @@ -0,0 +1,101 @@ +""" +Heavily inspired by (and largely stolen from)`greenletio`: +https://github.com/miguelgrinberg/greenletio +""" + +from greenlet import greenlet, getcurrent +import sys +import asyncio +import threading +import warnings +from typing import Awaitable, TypeVar, Callable, Union, Optional, cast +from typing_extensions import ParamSpec, Never + +P = ParamSpec("P") +T = TypeVar("T") + +class AwaitException(Exception): + """Exception raised when a coroutine is awaited in a non-greenlet context.""" + pass + +def await_(coro: Awaitable[T]) -> T: + """ + Sends a coroutine to the parent greenlet. 
The parent greenlet should either + 1. await the coroutine in an async function + 2. await_ the coroutine in a sync function, punting the problem further up + + If there is no parent greenlet, we'll call the usual asyncio.run() to run the + coroutine. + + If this fails due to an existing event loop, that means that the caller is a + foreign async function (not one of "ours"), and they need to use one of our + async entry-points to run the coroutine with await. + """ + parent_gl = getcurrent().parent + if parent_gl is None: + return run_async_maybe_in_thread(coro) + return cast(T, parent_gl.switch(coro)) + +def async_(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: + """ + Decorator to convert a synchronous function into an asynchronous one. + + If `await_` is called somewhere down the call stack, we are prepared to + receive it and take responsibility for awaiting the coroutine. + """ + async def decorator(*args: P.args, **kwargs: P.kwargs) -> T: + gl = greenlet(fn) + coro: Union[T, Awaitable[T]] = gl.switch(*args, **kwargs) + while gl: + coro = cast(Awaitable[T], coro) + try: + result = await coro + except: # noqa: E722 + # this catches exceptions from async functions awaited in + # sync code, and re-raises them in the greenlet + coro = cast(Never, gl.throw(*sys.exc_info())) + else: + coro = cast(Union[T, Awaitable[T]], gl.switch(result)) + return coro + return decorator + + +def run_async_maybe_in_thread( + coro: Awaitable[T] +) -> T: + """ + Run a coroutine in a thread if not already in an async context. + """ + try: + asyncio.get_running_loop() + except RuntimeError: + # Not in an async context, run the coroutine in a thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + else: + warnings.warn( + "Synchronous access to model state from an async context is ill-advised...", + stacklevel=2 + ) + # We're already in an async loop, so we have to run the coroutine in a nested event loop. + # TODO: consider raising an exception (sync guidance function called in async context) + # TODO: consider using some global thread and call asyncio.run_coroutine_threadsafe + result: Optional[T] = None + exception: Optional[Exception] = None + event = threading.Event() + def run(): + nonlocal result, exception + try: + result = cast(T, asyncio.run(coro)) + except Exception as ex: + exception = ex + finally: + event.set() + thread = threading.Thread(target=run) + thread.start() + event.wait() + thread.join() + if exception is not None: + raise exception + return cast(T, result) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index f699028a2..9a8e5a2f4 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -7,6 +7,7 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union import warnings +from ..._bridge import async_, await_ from typing_extensions import Self @@ -270,8 +271,11 @@ async def _run(self) -> None: if isinstance(func, AsyncFunction): new_self = await func(new_self) else: - # TODO: maybe run in thread to avoid blocking? - new_self = func(new_self) + # If someone awaits us directly (i.e. we're not below an `await_`), + # we need to wrap the sync part in `async_` to avoid blocking our caller's + # event loop. 
+ # Otherwise, this is effectively equivalent to func(new_self) + new_self = await async_(func)(new_self) self.__dict__ = new_self.__dict__ # I guess if self._pending is None: return @@ -282,24 +286,7 @@ async def _run(self) -> None: await self._run_node(node) def _run_sync(self) -> None: - new_self = self.copy() - # may be some pending blocks - new_self._apply_blocks() - while isinstance(new_self._pending, (Function, AsyncFunction)): - func = new_self._pending - new_self._pending = None - if isinstance(func, AsyncFunction): - # TODO: share a bg thread - new_self = run_async_maybe_in_thread(func(new_self)) - else: - new_self = func(new_self) - self.__dict__ = new_self.__dict__ # I guess - if self._pending is None: - return - assert isinstance(self._pending, ASTNode) - node = self._pending - self._pending = None - run_async_maybe_in_thread(self._run_node(node)) + await_(self._run()) async def _run_node(self, node: ASTNode) -> None: async for output_attr in self._interpreter.run(node): @@ -407,43 +394,3 @@ def target(ctx): # Reset the event queues context variable _event_queues.reset(token) - -def run_async_maybe_in_thread( - coro -): - """ - Run a coroutine in a thread if not already in an async context. - """ - try: - asyncio.get_running_loop() - except RuntimeError: - # Not in an async context, run the coroutine in a thread - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete(coro) - else: - warnings.warn( - "Synchronous access to model state from an async context is ill-advised...", - stacklevel=2 - ) - # We're already in an async loop, so we have to run the coroutine in a nested event loop. - # TODO: consider raising an exception (sync guidance function called in async context) - # TODO: consider using some global thread and call asyncio.run_coroutine_threadsafe - result = None - exception = None - event = threading.Event() - def run(): - nonlocal result, exception - try: - result = asyncio.run(coro) - except Exception as ex: - exception = ex - finally: - event.set() - thread = threading.Thread(target=run) - thread.start() - event.wait() - thread.join() - if exception is not None: - raise exception - return result From be03577fb669e4baf9a6322840b3295ca67cc82b Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 17 Apr 2025 10:47:28 -0700 Subject: [PATCH 17/56] doc --- guidance/_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index e41670cb5..1fa8eb7eb 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -1,5 +1,5 @@ """ -Heavily inspired by (and largely stolen from)`greenletio`: +Heavily inspired by (read: largely stolen from) https://github.com/miguelgrinberg/greenletio """ From 763e10b6b67029e2cff5188cb881cb86895680b3 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 17 Apr 2025 10:47:51 -0700 Subject: [PATCH 18/56] fix wrong comment and use higher-level asyncio run --- guidance/_bridge.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index 1fa8eb7eb..96f1d083b 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -69,10 +69,7 @@ def run_async_maybe_in_thread( try: asyncio.get_running_loop() except RuntimeError: - # Not in an async context, run the coroutine in a thread - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop.run_until_complete(coro) + return asyncio.run(coro) else: warnings.warn( "Synchronous access to model state from an async context is 
ill-advised...", From 63bb6558ec57f0bc35b7836aa75921dbc1ec68d6 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 17 Apr 2025 10:51:38 -0700 Subject: [PATCH 19/56] don't try to await_ in a thread --- guidance/_bridge.py | 49 +++++++++------------------------------------ 1 file changed, 9 insertions(+), 40 deletions(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index 96f1d083b..14d33e73a 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -14,7 +14,7 @@ P = ParamSpec("P") T = TypeVar("T") -class AwaitException(Exception): +class AwaitException(RuntimeError): """Exception raised when a coroutine is awaited in a non-greenlet context.""" pass @@ -33,7 +33,14 @@ def await_(coro: Awaitable[T]) -> T: """ parent_gl = getcurrent().parent if parent_gl is None: - return run_async_maybe_in_thread(coro) + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(coro) + else: + raise AwaitException( + "Cannot use synchronous await_ within a running event loop." + ) return cast(T, parent_gl.switch(coro)) def async_(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: @@ -58,41 +65,3 @@ async def decorator(*args: P.args, **kwargs: P.kwargs) -> T: coro = cast(Union[T, Awaitable[T]], gl.switch(result)) return coro return decorator - - -def run_async_maybe_in_thread( - coro: Awaitable[T] -) -> T: - """ - Run a coroutine in a thread if not already in an async context. - """ - try: - asyncio.get_running_loop() - except RuntimeError: - return asyncio.run(coro) - else: - warnings.warn( - "Synchronous access to model state from an async context is ill-advised...", - stacklevel=2 - ) - # We're already in an async loop, so we have to run the coroutine in a nested event loop. - # TODO: consider raising an exception (sync guidance function called in async context) - # TODO: consider using some global thread and call asyncio.run_coroutine_threadsafe - result: Optional[T] = None - exception: Optional[Exception] = None - event = threading.Event() - def run(): - nonlocal result, exception - try: - result = cast(T, asyncio.run(coro)) - except Exception as ex: - exception = ex - finally: - event.set() - thread = threading.Thread(target=run) - thread.start() - event.wait() - thread.join() - if exception is not None: - raise exception - return cast(T, result) From ffa5c6d107f53ed2143f2a13ca5cf0ec93dac705 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 17 Apr 2025 10:53:16 -0700 Subject: [PATCH 20/56] close the coro if we're not going to run it --- guidance/_bridge.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index 14d33e73a..8826ee45d 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -8,7 +8,7 @@ import asyncio import threading import warnings -from typing import Awaitable, TypeVar, Callable, Union, Optional, cast +from typing import Any, Awaitable, Coroutine, TypeVar, Callable, Union, Optional, cast from typing_extensions import ParamSpec, Never P = ParamSpec("P") @@ -18,7 +18,7 @@ class AwaitException(RuntimeError): """Exception raised when a coroutine is awaited in a non-greenlet context.""" pass -def await_(coro: Awaitable[T]) -> T: +def await_(coro: Coroutine[Any, Any, T]) -> T: """ Sends a coroutine to the parent greenlet. The parent greenlet should either 1. 
await the coroutine in an async function @@ -38,6 +38,8 @@ def await_(coro: Awaitable[T]) -> T: except RuntimeError: return asyncio.run(coro) else: + # Close the coro to avoid leaking resources + coro.close() raise AwaitException( "Cannot use synchronous await_ within a running event loop." ) From ba547f62073b97b64a24fae84a8b61abc1815ba6 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 17 Apr 2025 10:56:27 -0700 Subject: [PATCH 21/56] black, isort, mypy --- guidance/_bridge.py | 47 ++++++++++++++++++++++++--------------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index 8826ee45d..cd8d3a5ca 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -3,33 +3,35 @@ https://github.com/miguelgrinberg/greenletio """ -from greenlet import greenlet, getcurrent -import sys import asyncio -import threading -import warnings -from typing import Any, Awaitable, Coroutine, TypeVar, Callable, Union, Optional, cast -from typing_extensions import ParamSpec, Never +import sys +from typing import Any, Callable, Coroutine, TypeVar, cast + +from greenlet import getcurrent, greenlet # type: ignore[import-untyped] +from typing_extensions import ParamSpec P = ParamSpec("P") T = TypeVar("T") + class AwaitException(RuntimeError): """Exception raised when a coroutine is awaited in a non-greenlet context.""" + pass + def await_(coro: Coroutine[Any, Any, T]) -> T: """ - Sends a coroutine to the parent greenlet. The parent greenlet should either - 1. await the coroutine in an async function - 2. await_ the coroutine in a sync function, punting the problem further up + Sends a coroutine to the parent greenlet. The parent greenlet should either + 1. await the coroutine in an async function + 2. await_ the coroutine in a sync function, punting the problem further up - If there is no parent greenlet, we'll call the usual asyncio.run() to run the - coroutine. + If there is no parent greenlet, we'll call the usual asyncio.run() to run the + coroutine. - If this fails due to an existing event loop, that means that the caller is a - foreign async function (not one of "ours"), and they need to use one of our - async entry-points to run the coroutine with await. + If this fails due to an existing event loop, that means that the caller is a + foreign async function (not one of "ours"), and they need to use one of our + async entry-points to run the coroutine with await. """ parent_gl = getcurrent().parent if parent_gl is None: @@ -40,30 +42,31 @@ def await_(coro: Coroutine[Any, Any, T]) -> T: else: # Close the coro to avoid leaking resources coro.close() - raise AwaitException( - "Cannot use synchronous await_ within a running event loop." - ) + raise AwaitException("Cannot use synchronous await_ within a running event loop.") return cast(T, parent_gl.switch(coro)) -def async_(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: + +def async_(fn: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]: """ Decorator to convert a synchronous function into an asynchronous one. If `await_` is called somewhere down the call stack, we are prepared to receive it and take responsibility for awaiting the coroutine. 
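The `await_`/`async_` pair is easiest to see in a stripped-down form. Below is a self-contained greenlet trampoline sketch (assumes `pip install greenlet`; exception forwarding via `gl.throw`, which the real decorator handles, is omitted for brevity):

```python
import asyncio
from greenlet import getcurrent, greenlet

def sync_await(coro):
    # Suspend this greenlet and hand `coro` to the parent; whatever value the
    # parent switches back in becomes our return value.
    return getcurrent().parent.switch(coro)

def sync_workflow() -> str:
    a = sync_await(asyncio.sleep(0.01, result="hello"))  # reads as blocking code
    b = sync_await(asyncio.sleep(0.01, result="world"))
    return f"{a} {b}"

async def run_reentrant(fn, *args):
    gl = greenlet(fn)
    value = gl.switch(*args)   # run fn until its first sync_await (or its return)
    while gl:                  # a greenlet stays truthy until its function returns
        value = gl.switch(await value)  # await on fn's behalf, then resume it
    return value               # the final switch delivered fn's return value

print(asyncio.run(run_reentrant(sync_workflow)))  # -> hello world
```

The payoff is that `sync_workflow` never blocks the event loop, even though it is written as ordinary synchronous code.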
""" + async def decorator(*args: P.args, **kwargs: P.kwargs) -> T: gl = greenlet(fn) - coro: Union[T, Awaitable[T]] = gl.switch(*args, **kwargs) + coro = gl.switch(*args, **kwargs) while gl: - coro = cast(Awaitable[T], coro) + coro = coro try: result = await coro except: # noqa: E722 # this catches exceptions from async functions awaited in # sync code, and re-raises them in the greenlet - coro = cast(Never, gl.throw(*sys.exc_info())) + coro = gl.throw(*sys.exc_info()) else: - coro = cast(Union[T, Awaitable[T]], gl.switch(result)) + coro = gl.switch(result) return coro + return decorator From 427f714f8938f4251e3d5fc7f0606c78ca26251a Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 17 Apr 2025 11:52:43 -0700 Subject: [PATCH 22/56] assert we actually get an exception when using sync accessors in async guidance functions --- tests/unit/test_async.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_async.py b/tests/unit/test_async.py index 22313814d..11683673e 100644 --- a/tests/unit/test_async.py +++ b/tests/unit/test_async.py @@ -1,6 +1,7 @@ import asyncio import pytest -from guidance import * +from guidance import guidance, models, select +from guidance._bridge import AwaitException @guidance def sync_func(lm: models.Model): @@ -100,8 +101,10 @@ def test_async(gfunc, expected, sync): ], ) def test_async_with_sync_accessor(sync): - s = run(async_func_with_sync_accessor, sync) - assert s in {"alpha", "beta"} + # This should raise an AwaitException because the sync accessor is not + # allowed in the async function + with pytest.raises(AwaitException): + run(async_func_with_sync_accessor, sync) def test_sync_accessor_in_foreign_event_loop(): async def main(): From 394d4154e9ce76b67e638200d0714ace79786a85 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 12:48:21 -0700 Subject: [PATCH 23/56] copy context vars into greenlet --- guidance/_bridge.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index cd8d3a5ca..7f9353d78 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -5,6 +5,7 @@ import asyncio import sys +import contextvars from typing import Any, Callable, Coroutine, TypeVar, cast from greenlet import getcurrent, greenlet # type: ignore[import-untyped] @@ -56,6 +57,7 @@ def async_(fn: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]: async def decorator(*args: P.args, **kwargs: P.kwargs) -> T: gl = greenlet(fn) + gl.gr_context = contextvars.copy_context() coro = gl.switch(*args, **kwargs) while gl: coro = coro From 70ee3f9dff20153e02540b6d6e0816418ecbe506 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 13:46:31 -0700 Subject: [PATCH 24/56] more tests --- tests/unit/test_async.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_async.py b/tests/unit/test_async.py index 11683673e..828dc0919 100644 --- a/tests/unit/test_async.py +++ b/tests/unit/test_async.py @@ -22,7 +22,8 @@ async def async_func(lm: models.Model): return lm @guidance -async def async_func_with_sync_accessor(lm: models.Model): +def sync_func_calling_sync_func(lm: models.Model): + lm += sync_func() lm += select(["a", "b"], name="choice") if lm.get("choice") == "a": lm += "lpha" @@ -40,6 +41,16 @@ def sync_func_calling_async_func(lm: models.Model): lm += "eta" return lm +@guidance +async def async_func_calling_async_func(lm: models.Model): + lm += async_func() + lm += select(["a", "b"], name="choice") + if (await 
lm.get_async("choice")) == "a": + lm += "lpha" + else: + lm += "eta" + return lm + @guidance async def async_func_calling_sync_func(lm: models.Model): lm += sync_func() @@ -85,6 +96,8 @@ def run(gfunc, sync: bool) -> str: (async_func, {"alpha", "beta"}), (sync_then_async, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), (async_then_sync, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), + (sync_func_calling_sync_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), + (async_func_calling_async_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), (async_func_calling_sync_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), (sync_func_calling_async_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}), ], @@ -101,6 +114,14 @@ def test_async(gfunc, expected, sync): ], ) def test_async_with_sync_accessor(sync): + @guidance + async def async_func_with_sync_accessor(lm: models.Model): + lm += select(["a", "b"], name="choice") + if lm.get("choice") == "a": + lm += "lpha" + else: + lm += "eta" + return lm # This should raise an AwaitException because the sync accessor is not # allowed in the async function with pytest.raises(AwaitException): From 850ebd1a5c434c0cc29d2687519cd7fc41da3395 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 14:09:25 -0700 Subject: [PATCH 25/56] entry point decorator to help determine whether to run reentrant awaits in bg event loop --- guidance/_bridge.py | 33 +++++++++++++++++++++++++++------ guidance/models/_base/_model.py | 5 +++-- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index 7f9353d78..a22ac5dea 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -5,8 +5,11 @@ import asyncio import sys +from contextlib import contextmanager +from functools import wraps import contextvars -from typing import Any, Callable, Coroutine, TypeVar, cast +from typing import Any, Callable, Coroutine, TypeVar, cast, Awaitable +import inspect from greenlet import getcurrent, greenlet # type: ignore[import-untyped] from typing_extensions import ParamSpec @@ -14,6 +17,18 @@ P = ParamSpec("P") T = TypeVar("T") +_entered: contextvars.ContextVar[bool] = contextvars.ContextVar("entered", default=False) + +def async_entry_point(func: Callable[P, Awaitable[T]]) -> Callable[P, Coroutine[Any, Any, T]]: + @wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + token = _entered.set(True) + try: + return await func(*args, **kwargs) + finally: + _entered.reset(token) + return wrapper + class AwaitException(RuntimeError): """Exception raised when a coroutine is awaited in a non-greenlet context.""" @@ -36,12 +51,18 @@ def await_(coro: Coroutine[Any, Any, T]) -> T: """ parent_gl = getcurrent().parent if parent_gl is None: - try: - asyncio.get_running_loop() - except RuntimeError: - return asyncio.run(coro) + if not _entered.get(): + from .registry import get_bg_async + import threading + bg_async = get_bg_async() + thread, _ = bg_async._thread_and_loop() + if thread is threading.current_thread(): + raise RuntimeError( + "Cannot nest async call -- already in background thread." 
+ ) + fut = bg_async.run_async_coroutine(coro) + return fut.result() else: - # Close the coro to avoid leaking resources coro.close() raise AwaitException("Cannot use synchronous await_ within a running event loop.") return cast(T, parent_gl.switch(coro)) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 9a8e5a2f4..97c044328 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union import warnings -from ..._bridge import async_, await_ +from ..._bridge import async_, await_, async_entry_point from typing_extensions import Self @@ -261,6 +261,7 @@ def __getattribute__(self, name): return getattr(self._interpreter, "engine") return super().__getattribute__(name) + @async_entry_point async def _run(self) -> None: new_self = self.copy() # may be some pending blocks @@ -286,7 +287,7 @@ async def _run(self) -> None: await self._run_node(node) def _run_sync(self) -> None: - await_(self._run()) + return await_(self._run()) async def _run_node(self, node: ASTNode) -> None: async for output_attr in self._interpreter.run(node): From 61c28eea1b2ea2eb0349895ed49e47aad90871f1 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 14:59:19 -0700 Subject: [PATCH 26/56] factor out run in bg thread --- guidance/_bridge.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index a22ac5dea..441fec986 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -3,13 +3,11 @@ https://github.com/miguelgrinberg/greenletio """ -import asyncio import sys -from contextlib import contextmanager from functools import wraps import contextvars +import threading from typing import Any, Callable, Coroutine, TypeVar, cast, Awaitable -import inspect from greenlet import getcurrent, greenlet # type: ignore[import-untyped] from typing_extensions import ParamSpec @@ -29,6 +27,19 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: _entered.reset(token) return wrapper +def run_coro_in_bg_thread(coro: Coroutine[Any, Any, T]) -> T: + """ + Run a coroutine in the background thread and return the result. + This is a blocking call. + """ + from .registry import get_bg_async + + bg_async = get_bg_async() + thread, _ = bg_async._thread_and_loop() + if thread is threading.current_thread(): + raise RuntimeError("Cannot nest async call -- already in background thread.") + fut = bg_async.run_async_coroutine(coro) + return fut.result() class AwaitException(RuntimeError): """Exception raised when a coroutine is awaited in a non-greenlet context.""" @@ -52,16 +63,7 @@ def await_(coro: Coroutine[Any, Any, T]) -> T: parent_gl = getcurrent().parent if parent_gl is None: if not _entered.get(): - from .registry import get_bg_async - import threading - bg_async = get_bg_async() - thread, _ = bg_async._thread_and_loop() - if thread is threading.current_thread(): - raise RuntimeError( - "Cannot nest async call -- already in background thread." 
- ) - fut = bg_async.run_async_coroutine(coro) - return fut.result() + return run_coro_in_bg_thread(coro) else: coro.close() raise AwaitException("Cannot use synchronous await_ within a running event loop.") From 7698c266821efb9194c1c04809899b988e139c1f Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 15:58:36 -0700 Subject: [PATCH 27/56] clean up bridge a bit --- guidance/_bridge.py | 74 ++++++++++----------------------- guidance/models/_base/_model.py | 69 +++++++++++++++++++----------- tests/unit/test_async.py | 4 +- 3 files changed, 67 insertions(+), 80 deletions(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index 441fec986..fee8aee14 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -6,79 +6,47 @@ import sys from functools import wraps import contextvars -import threading -from typing import Any, Callable, Coroutine, TypeVar, cast, Awaitable +from typing import Any, Callable, Coroutine, TypeVar, cast +from functools import wraps from greenlet import getcurrent, greenlet # type: ignore[import-untyped] from typing_extensions import ParamSpec -P = ParamSpec("P") -T = TypeVar("T") - -_entered: contextvars.ContextVar[bool] = contextvars.ContextVar("entered", default=False) - -def async_entry_point(func: Callable[P, Awaitable[T]]) -> Callable[P, Coroutine[Any, Any, T]]: - @wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: - token = _entered.set(True) - try: - return await func(*args, **kwargs) - finally: - _entered.reset(token) - return wrapper - -def run_coro_in_bg_thread(coro: Coroutine[Any, Any, T]) -> T: - """ - Run a coroutine in the background thread and return the result. - This is a blocking call. - """ - from .registry import get_bg_async - - bg_async = get_bg_async() - thread, _ = bg_async._thread_and_loop() - if thread is threading.current_thread(): - raise RuntimeError("Cannot nest async call -- already in background thread.") - fut = bg_async.run_async_coroutine(coro) - return fut.result() -class AwaitException(RuntimeError): +class ReentrantAsyncException(RuntimeError): """Exception raised when a coroutine is awaited in a non-greenlet context.""" - pass -def await_(coro: Coroutine[Any, Any, T]) -> T: - """ - Sends a coroutine to the parent greenlet. The parent greenlet should either - 1. await the coroutine in an async function - 2. await_ the coroutine in a sync function, punting the problem further up +P = ParamSpec("P") +T = TypeVar("T") - If there is no parent greenlet, we'll call the usual asyncio.run() to run the - coroutine. - If this fails due to an existing event loop, that means that the caller is a - foreign async function (not one of "ours"), and they need to use one of our - async entry-points to run the coroutine with await. +def reentrant_await(coro: Coroutine[Any, Any, T]) -> T: """ + Sends a coroutine to the parent greenlet, which is expected to await the coroutine + for us and send the result back. + + When there is no parent greenlet, we raise a ReentrantAsyncException. 
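Two details of this error path merit a standalone illustration: the coroutine is closed before raising, so CPython does not emit a "coroutine ... was never awaited" RuntimeWarning when it is garbage-collected, and subclassing RuntimeError keeps existing broad handlers working. A minimal sketch:

```python
class ReentrantAsyncException(RuntimeError):
    """No parent greenlet is available to await on our behalf."""

async def work():
    return 1

def refuse(coro):
    # Without close(), dropping the never-awaited coroutine would log
    # "RuntimeWarning: coroutine 'work' was never awaited" at GC time.
    coro.close()
    raise ReentrantAsyncException("synchronous entry-point used in async context")

try:
    refuse(work())
except RuntimeError as exc:  # broad handlers still catch the subclass
    print(exc)
```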
+ """ + parent_gl = getcurrent().parent if parent_gl is None: - if not _entered.get(): - return run_coro_in_bg_thread(coro) - else: - coro.close() - raise AwaitException("Cannot use synchronous await_ within a running event loop.") + coro.close() + raise ReentrantAsyncException("Attempted to use synchronous entry-point in async context") return cast(T, parent_gl.switch(coro)) -def async_(fn: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]: +def sync_to_reentrant_async(fn: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]: """ - Decorator to convert a synchronous function into an asynchronous one. + Decorator to convert a synchronous function into a re-entrant asynchronous one. - If `await_` is called somewhere down the call stack, we are prepared to - receive it and take responsibility for awaiting the coroutine. + Calls to `reentrant_await` down the stack will bounce back here, and we we'll await + the coroutine for them """ - async def decorator(*args: P.args, **kwargs: P.kwargs) -> T: + @wraps(fn) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: gl = greenlet(fn) gl.gr_context = contextvars.copy_context() coro = gl.switch(*args, **kwargs) @@ -94,4 +62,4 @@ async def decorator(*args: P.args, **kwargs: P.kwargs) -> T: coro = gl.switch(result) return coro - return decorator + return wrapper diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 97c044328..a49f36b07 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union import warnings -from ..._bridge import async_, await_, async_entry_point +from ..._bridge import sync_to_reentrant_async, reentrant_await from typing_extensions import Self @@ -43,6 +43,7 @@ if TYPE_CHECKING: from ...library._block import Block +_below_entry_point: ContextVar[bool] = ContextVar("below_entry_point", default=False) _active_blocks: ContextVar[tuple["Block", ...]] = ContextVar("active_blocks", default=()) _event_queues: ContextVar[tuple[queue.Queue["Model"], ...]] = ContextVar( "event_queues", default=() @@ -261,33 +262,51 @@ def __getattribute__(self, name): return getattr(self._interpreter, "engine") return super().__getattribute__(name) - @async_entry_point async def _run(self) -> None: - new_self = self.copy() - # may be some pending blocks - new_self._apply_blocks() - while isinstance(new_self._pending, (Function, AsyncFunction)): - func = new_self._pending - new_self._pending = None - if isinstance(func, AsyncFunction): - new_self = await func(new_self) - else: - # If someone awaits us directly (i.e. we're not below an `await_`), - # we need to wrap the sync part in `async_` to avoid blocking our caller's - # event loop. - # Otherwise, this is effectively equivalent to func(new_self) - new_self = await async_(func)(new_self) - self.__dict__ = new_self.__dict__ # I guess - if self._pending is None: - return - - assert isinstance(self._pending, ASTNode) - node = self._pending - self._pending = None - await self._run_node(node) + async def inner(): + new_self = self.copy() + # may be some pending blocks + new_self._apply_blocks() + while isinstance(new_self._pending, (Function, AsyncFunction)): + func = new_self._pending + new_self._pending = None + if isinstance(func, AsyncFunction): + new_self = await func(new_self) + else: + # If someone awaits us directly (i.e. 
we're not below an `await_`), + # we need to wrap the sync part in `async_` to avoid blocking our caller's + # event loop. + # Otherwise, this is effectively equivalent to func(new_self) + new_self = await sync_to_reentrant_async(func)(new_self) + self.__dict__ = new_self.__dict__ # I guess + if self._pending is None: + return + + assert isinstance(self._pending, ASTNode) + node = self._pending + self._pending = None + await self._run_node(node) + + # Mark that we are below the entry point so that + # `_run_sync` knows to use `await_` instead of + # running in the background thread. + token = _below_entry_point.set(True) + try: + return await inner() + finally: + _below_entry_point.reset(token) def _run_sync(self) -> None: - return await_(self._run()) + if not _below_entry_point.get(): + from ...registry import get_bg_async + bg_async = get_bg_async() + thread, _ = bg_async._thread_and_loop() + if thread is threading.current_thread(): + raise RuntimeError("Cannot nest async call -- already in background thread.") + fut = bg_async.run_async_coroutine(self._run()) + return fut.result() + + return reentrant_await(self._run()) async def _run_node(self, node: ASTNode) -> None: async for output_attr in self._interpreter.run(node): diff --git a/tests/unit/test_async.py b/tests/unit/test_async.py index 828dc0919..6c08fe890 100644 --- a/tests/unit/test_async.py +++ b/tests/unit/test_async.py @@ -1,7 +1,7 @@ import asyncio import pytest from guidance import guidance, models, select -from guidance._bridge import AwaitException +from guidance._bridge import ReentrantAsyncException @guidance def sync_func(lm: models.Model): @@ -124,7 +124,7 @@ async def async_func_with_sync_accessor(lm: models.Model): return lm # This should raise an AwaitException because the sync accessor is not # allowed in the async function - with pytest.raises(AwaitException): + with pytest.raises(ReentrantAsyncException): run(async_func_with_sync_accessor, sync) def test_sync_accessor_in_foreign_event_loop(): From ce3d7a063ba347ee13ca3e4efa400fdcb5fe46a7 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 16:13:02 -0700 Subject: [PATCH 28/56] make bg_async generic --- guidance/_bg/__init__.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/guidance/_bg/__init__.py b/guidance/_bg/__init__.py index c6f150624..0490e7d4e 100644 --- a/guidance/_bg/__init__.py +++ b/guidance/_bg/__init__.py @@ -6,15 +6,16 @@ import asyncio import threading from asyncio import AbstractEventLoop, Future, Task -from typing import Tuple, Coroutine +from typing import Coroutine, Any, TypeVar +T = TypeVar("T") def _start_asyncio_loop(loop: AbstractEventLoop): asyncio.set_event_loop(loop) loop.run_forever() -def _asyncio_background_thread() -> Tuple[threading.Thread, AbstractEventLoop]: +def _asyncio_background_thread() -> tuple[threading.Thread, AbstractEventLoop]: loop = asyncio.new_event_loop() thread = threading.Thread(target=_start_asyncio_loop, args=(loop,)) thread.daemon = True @@ -29,7 +30,7 @@ def __init__(self): self._loop = None self._thread = None - def _thread_and_loop(self) -> Tuple[threading.Thread, AbstractEventLoop]: + def _thread_and_loop(self) -> tuple[threading.Thread, AbstractEventLoop]: if self._loop is None: self._thread, self._loop = _asyncio_background_thread() self._thread.start() @@ -41,7 +42,7 @@ def call_soon_threadsafe(self, cb, *args, context = None): _, loop = self._thread_and_loop() return loop.call_soon_threadsafe(cb, *args, context=context) - def run_async_coroutine(self, 
coroutine: Coroutine) -> Future: + def run_async_coroutine(self, coroutine: Coroutine[Any, Any, T]) -> Future[T]: """ Runs an asynchronous coroutine in the visual thread. Args: @@ -55,7 +56,7 @@ def run_async_coroutine(self, coroutine: Coroutine) -> Future: return future @staticmethod - async def async_task(coroutine: Coroutine) -> Task: + async def async_task(coroutine: Coroutine[Any, Any, T]) -> Task[T]: """ Creates an asyncio task from coroutine. Args: From 4839742b72bbab01d989f0495cb58c3b33153d02 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 16:14:56 -0700 Subject: [PATCH 29/56] move run in bg async to bridge --- guidance/_bridge.py | 22 ++++++++++++++++++++-- guidance/models/_base/_model.py | 11 ++--------- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/guidance/_bridge.py b/guidance/_bridge.py index fee8aee14..37536a147 100644 --- a/guidance/_bridge.py +++ b/guidance/_bridge.py @@ -3,11 +3,11 @@ https://github.com/miguelgrinberg/greenletio """ +import contextvars import sys +import threading from functools import wraps -import contextvars from typing import Any, Callable, Coroutine, TypeVar, cast -from functools import wraps from greenlet import getcurrent, greenlet # type: ignore[import-untyped] from typing_extensions import ParamSpec @@ -15,6 +15,7 @@ class ReentrantAsyncException(RuntimeError): """Exception raised when a coroutine is awaited in a non-greenlet context.""" + pass @@ -63,3 +64,20 @@ async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: return coro return wrapper + + +def run_async_coroutine_in_bg_async(coro: Coroutine[Any, Any, T]) -> T: + """ + Run a coroutine in the background thread and wait for it to finish. + (This is a blocking call.) + """ + + from .registry import get_bg_async + + bg_async = get_bg_async() + thread, _ = bg_async._thread_and_loop() + if thread is threading.current_thread(): + coro.close() + raise RuntimeError("Cannot nest async call -- already in background thread.") + fut = bg_async.run_async_coroutine(coro) + return fut.result() diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index a49f36b07..7833227d2 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union import warnings -from ..._bridge import sync_to_reentrant_async, reentrant_await +from ..._bridge import sync_to_reentrant_async, reentrant_await, run_async_coroutine_in_bg_async from typing_extensions import Self @@ -298,14 +298,7 @@ async def inner(): def _run_sync(self) -> None: if not _below_entry_point.get(): - from ...registry import get_bg_async - bg_async = get_bg_async() - thread, _ = bg_async._thread_and_loop() - if thread is threading.current_thread(): - raise RuntimeError("Cannot nest async call -- already in background thread.") - fut = bg_async.run_async_coroutine(self._run()) - return fut.result() - + return run_async_coroutine_in_bg_async(self._run()) return reentrant_await(self._run()) async def _run_node(self, node: ASTNode) -> None: From 4d8d3eec710b403ba7d4f0bf12f6b65d000ddbe2 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 16:15:32 -0700 Subject: [PATCH 30/56] guidance/_bridge.py -> guidance/_reentrant_async.py --- guidance/{_bridge.py => _reentrant_async.py} | 0 guidance/models/_base/_model.py | 2 +- tests/unit/test_async.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename guidance/{_bridge.py => _reentrant_async.py} 
(100%) diff --git a/guidance/_bridge.py b/guidance/_reentrant_async.py similarity index 100% rename from guidance/_bridge.py rename to guidance/_reentrant_async.py diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 7833227d2..1493ce5c7 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -7,7 +7,7 @@ from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union import warnings -from ..._bridge import sync_to_reentrant_async, reentrant_await, run_async_coroutine_in_bg_async +from ..._reentrant_async import sync_to_reentrant_async, reentrant_await, run_async_coroutine_in_bg_async from typing_extensions import Self diff --git a/tests/unit/test_async.py b/tests/unit/test_async.py index 6c08fe890..abcbe5702 100644 --- a/tests/unit/test_async.py +++ b/tests/unit/test_async.py @@ -1,7 +1,7 @@ import asyncio import pytest from guidance import guidance, models, select -from guidance._bridge import ReentrantAsyncException +from guidance._reentrant_async import ReentrantAsyncException @guidance def sync_func(lm: models.Model): From 8c288734d34fb44ab0852e5cc842bf9a6db40feb Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Mon, 21 Apr 2025 16:39:42 -0700 Subject: [PATCH 31/56] clear active blocks in repeated function application --- guidance/models/_base/_model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 1493ce5c7..cf8f11afb 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -270,6 +270,7 @@ async def inner(): while isinstance(new_self._pending, (Function, AsyncFunction)): func = new_self._pending new_self._pending = None + new_self._active_blocks = () if isinstance(func, AsyncFunction): new_self = await func(new_self) else: From 06c309c42b0fff970ff304570ea3cc44ac81e4e3 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 09:25:31 -0700 Subject: [PATCH 32/56] make sure to clear pending blocks after function application --- guidance/models/_base/_model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index cf8f11afb..1dab44ab4 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -279,6 +279,8 @@ async def inner(): # event loop. 
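The loop being patched here is the heart of the lazy model: `+` only queues work, and `_run` later drains the queue, re-checking blocks after every function application. A toy version of that contract (hypothetical names; the real Model also tracks roles, captures, and trace nodes):

```python
import asyncio

class LazyModel:
    def __init__(self, text="", pending=()):
        self.text, self._pending = text, pending

    def __add__(self, other):
        # Cheap and non-blocking: just record the work.
        return LazyModel(self.text, self._pending + (other,))

    async def _run(self):
        text = self.text
        for item in self._pending:
            if isinstance(item, str):
                text += item             # plain node: append
            else:
                text = await item(text)  # guidance-style function: transform
        return LazyModel(text)

    def __str__(self):
        # Sync access forces execution; the real code routes this through the
        # re-entrant bridge rather than a bare asyncio.run().
        return asyncio.run(self._run()).text

async def exclaim(text):  # stands in for an async guidance function
    return text + "!"

lm = LazyModel() + "hello" + ", world" + exclaim
print(lm)  # hello, world!
```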
# Otherwise, this is effectively equivalent to func(new_self) new_self = await sync_to_reentrant_async(func)(new_self) + # may be some pending blocks + new_self._apply_blocks() self.__dict__ = new_self.__dict__ # I guess if self._pending is None: return From 111efc0929575ca83af811c015e5c45eea012304 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 09:45:50 -0700 Subject: [PATCH 33/56] add 'batched' entrypoints --- guidance/models/_base/_model.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 1dab44ab4..872fdce2e 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -5,7 +5,7 @@ import asyncio from contextvars import ContextVar, copy_context from copy import deepcopy -from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union, Sequence import warnings from ..._reentrant_async import sync_to_reentrant_async, reentrant_await, run_async_coroutine_in_bg_async @@ -262,6 +262,17 @@ def __getattribute__(self, name): return getattr(self._interpreter, "engine") return super().__getattribute__(name) + async def run_batched_async(self, items: Sequence[Union[str, Function, AsyncFunction, ASTNode]]) -> Self: + lms = [self + item for item in items] + coros = [lm._run() for lm in lms] + await asyncio.gather(*coros) + return lms + + def run_batched(self, items: Sequence[Union[str, Function, AsyncFunction, ASTNode]]) -> Self: + if not _below_entry_point.get(): + return run_async_coroutine_in_bg_async(self.run_batched_async(items)) + return reentrant_await(self.run_batched_async(items)) + async def _run(self) -> None: async def inner(): new_self = self.copy() From bc389e49affc52ed9e3fdac6925c1d277d7fb8da Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 10:02:57 -0700 Subject: [PATCH 34/56] add greenlet dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index b41ef420d..f2a87bcf1 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "pydantic", "requests", "psutil", + "greenlet", "guidance-stitch", "llguidance==0.7.16", ] From c9a922b9b4d98d94e48e038efc8c0e48eac3e094 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 14:36:51 -0700 Subject: [PATCH 35/56] fix capture blocks --- guidance/_ast.py | 1 - guidance/models/_base/_interpreter.py | 11 +++++++++++ guidance/models/_base/_model.py | 6 +++--- guidance/models/_base/_state.py | 14 ++++++++++++++ 4 files changed, 28 insertions(+), 4 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index b4d25dfd1..247609a2b 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -264,7 +264,6 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: @dataclass class CaptureStart(ASTNode): name: str - list_append: bool = False def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: return interpreter.capture_start(self, **kwargs) diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py index 136878313..f1415cd13 100644 --- a/guidance/models/_base/_interpreter.py +++ b/guidance/models/_base/_interpreter.py @@ -22,6 +22,8 @@ SelectNode, SubgrammarNode, SubstringNode, + CaptureStart, + CaptureEnd, ) from ..._utils import bytes_from from ...trace import OutputAttr @@ -36,6 +38,15 @@ def __init__(self, state: S): def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: return 
node.simplify()._run(self, **kwargs) + async def capture_start(self, node: CaptureStart, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.open_capture(node.name) + if False: + # Yes, this is intentional. + yield + + async def capture_end(self, node: CaptureEnd, **kwargs) -> AsyncIterable[OutputAttr]: + yield self.state.close_capture(node.name) + async def concatenate(self, node: Concatenate, **kwargs) -> AsyncIterable[OutputAttr]: buffer: Optional[GrammarNode] = None for child in node: diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 872fdce2e..b0cf8a31d 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -136,8 +136,8 @@ def _apply_blocks(self) -> None: if isinstance(closer, str): closer = _parse_tags(closer) self._add_to_pending(closer) - if block.name is not None: - self._add_to_pending(CaptureStart(name=block.name)) + if block.name is not None: + self._add_to_pending(CaptureEnd(name=block.name)) else: # Not closed, so keep it new_active_blocks.append(block) @@ -147,7 +147,7 @@ def _apply_blocks(self) -> None: if block not in self._active_blocks: new_active_blocks.append(block) if block.name is not None: - self._add_to_pending(CaptureEnd(name=block.name)) + self._add_to_pending(CaptureStart(name=block.name)) if block.opener is not None: opener = block.opener if isinstance(opener, str): diff --git a/guidance/models/_base/_state.py b/guidance/models/_base/_state.py index 38dd79de5..2e05b481c 100644 --- a/guidance/models/_base/_state.py +++ b/guidance/models/_base/_state.py @@ -13,6 +13,20 @@ class State(ABC): def __init__(self) -> None: self.captures: dict[str, Union[CaptureVar, list[CaptureVar]]] = {} self.active_role: Optional[str] = None + self.open_capture_blocks: dict[str, int] = {} + + def open_capture(self, name: str) -> None: + self.open_capture_blocks[name] = len(str(self)) + + def close_capture(self, name: str) -> CaptureOutput: + start_index = self.open_capture_blocks.pop(name) + value = str(self)[start_index:] + return self.apply_capture( + name=name, + value=value, + log_prob=None, + is_append=False, + ) @abstractmethod def __str__(self) -> str: From 7a750719e01c0ccfac0a6f451f255a24ade90222 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 14:44:10 -0700 Subject: [PATCH 36/56] fix smoke tests by adding str(lm) to trigger execution --- tests/model_integration/library/test_gen.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/model_integration/library/test_gen.py b/tests/model_integration/library/test_gen.py index 3a03aa91c..7653284b8 100644 --- a/tests/model_integration/library/test_gen.py +++ b/tests/model_integration/library/test_gen.py @@ -52,8 +52,10 @@ def test_metrics_smoke(selected_model: models.Model): lm.engine.reset_metrics() lm += "abcd" + str(lm) # trigger execution print(f"{lm.engine.metrics=}") lm += gen("first", max_tokens=1) + str(lm) # trigger execution print(f"{lm.engine.metrics=}") # Can't be sure of exact count due to token healing assert ( @@ -65,6 +67,7 @@ def test_metrics_smoke(selected_model: models.Model): lm += "fg" lm += gen("second", max_tokens=1) + str(lm) # trigger execution # Again, trouble with healing assert ( lm.engine.metrics.engine_output_tokens >= 2 @@ -85,7 +88,7 @@ def test_metrics_select(selected_model: models.Model): "go for a swim in the ocean", ] ) - print(f"lm={str(lm)}") + print(f"lm={str(lm)}") # trigger execution print(f"{lm.engine.metrics=}") assert lm.engine.metrics.engine_input_tokens > 1 assert 
lm.engine.metrics.engine_output_tokens > 0 @@ -106,6 +109,7 @@ def test_unicode(selected_model: models.Model): Step 1''' + gen('steps', list_append=True, stop=['\nStep', '\n\n', '\nAnswer'], temperature=0.7, max_tokens=20) + '\n' i = 2 lm + f'Step {i}:' + gen('steps', list_append=True, stop=['\nStep', '\n\n', '\nAnswer'], temperature=0.7, max_tokens=20) + '\n' + str(lm) # trigger execution # fmt: on @@ -114,6 +118,7 @@ def test_unicode2(selected_model: models.Model): lm.engine.reset_metrics() prompt = "Janet’s ducks lay 16 eggs per day" lm += prompt + gen(max_tokens=10) + str(lm) # trigger execution assert lm.engine.metrics.engine_input_tokens > 1 # Due to token healing, we can't be sure of the # precise output count From 788097b1509b7631b8baf53d27e0c79725197c5f Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 14:45:08 -0700 Subject: [PATCH 37/56] visual async test -- don't rely on there being an existing event loop --- tests/unit/test_visual.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_visual.py b/tests/unit/test_visual.py index 4cfc82cd6..00620fbd6 100644 --- a/tests/unit/test_visual.py +++ b/tests/unit/test_visual.py @@ -8,6 +8,7 @@ from guidance.visual import serialize_message, deserialize_message from guidance.visual._environment import Environment import asyncio +import threading from guidance.visual._exchange import DEFAULT_TOPIC @@ -34,8 +35,12 @@ def test_serialization(message): def test_async(): - _, loop = get_bg_async()._thread_and_loop() - assert loop != asyncio.get_event_loop() + thread, loop = get_bg_async()._thread_and_loop() + try: + assert loop != asyncio.get_event_loop() + except RuntimeError: + pass + assert thread != threading.current_thread() async def f(): return True From aaa1191ef9a1c5df693514af7fc249427be3be02 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 14:58:33 -0700 Subject: [PATCH 38/56] lm.get_token_count() --- guidance/library/_gen.py | 4 ++-- guidance/models/_base/_interpreter.py | 11 ++++++++--- guidance/models/_base/_model.py | 8 ++++++++ guidance/models/_base/_state.py | 1 + guidance/models/_engine/_interpreter.py | 2 +- tests/model_integration/test_model.py | 2 +- 6 files changed, 21 insertions(+), 7 deletions(-) diff --git a/guidance/library/_gen.py b/guidance/library/_gen.py index 88342cc28..6b7cd3488 100644 --- a/guidance/library/_gen.py +++ b/guidance/library/_gen.py @@ -133,10 +133,10 @@ def tool_gen(lm): ) grm = with_temperature(select(options), temperature) - initial_token_count = lm.token_count + initial_token_count = lm.get_token_count() tagged_name = "__LIST_APPEND:" + name if list_append and name is not None else name with block(tagged_name): - while lm.token_count <= max_tokens + initial_token_count: + while lm.get_token_count() <= max_tokens + initial_token_count: lm += grm tool_called = False for i in range(len(tools)): diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py index f1415cd13..aa15d767f 100644 --- a/guidance/models/_base/_interpreter.py +++ b/guidance/models/_base/_interpreter.py @@ -26,7 +26,7 @@ CaptureEnd, ) from ..._utils import bytes_from -from ...trace import OutputAttr +from ...trace import OutputAttr, TextOutput from ._state import State S = TypeVar("S", bound=State) @@ -35,8 +35,13 @@ class Interpreter(Generic[S], ABC): def __init__(self, state: S): self.state = state - def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: - return node.simplify()._run(self, **kwargs) + async 
def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: + async for attr in node.simplify()._run(self, **kwargs): + if isinstance(attr, TextOutput) and attr.is_generated: + print(attr, attr.token_count) + # TODO: this should probably be a lower-level responsibility? Not sure. + self.state.token_count += attr.token_count + yield attr async def capture_start(self, node: CaptureStart, **kwargs) -> AsyncIterable[OutputAttr]: self.state.open_capture(node.name) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index b0cf8a31d..ee6c4d594 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -350,6 +350,14 @@ async def length_async(self) -> int: """Get the length of the model.""" return len(await self.to_string_async()) + async def get_token_count_async(self) -> int: + """Get the token count of the model.""" + return (await self._get_state_async()).token_count + + def get_token_count(self) -> int: + """Get the token count of the model.""" + return self._get_state().token_count + class ModelStream: def __init__( self, diff --git a/guidance/models/_base/_state.py b/guidance/models/_base/_state.py index 2e05b481c..d27fdf3b9 100644 --- a/guidance/models/_base/_state.py +++ b/guidance/models/_base/_state.py @@ -14,6 +14,7 @@ def __init__(self) -> None: self.captures: dict[str, Union[CaptureVar, list[CaptureVar]]] = {} self.active_role: Optional[str] = None self.open_capture_blocks: dict[str, int] = {} + self.token_count: int = 0 def open_capture(self, name: str) -> None: self.open_capture_blocks[name] = len(str(self)) diff --git a/guidance/models/_engine/_interpreter.py b/guidance/models/_engine/_interpreter.py index b27563b25..72183c757 100644 --- a/guidance/models/_engine/_interpreter.py +++ b/guidance/models/_engine/_interpreter.py @@ -69,7 +69,7 @@ async def grammar(self, node: GrammarNode, **kwargs) -> AsyncIterable[OutputAttr # Update the state self.state.prompt += new_text - yield TextOutput(value=new_text, token_count=chunk.new_token_count, is_generated=True) + yield TextOutput(value=new_text, token_count=chunk.new_token_count, is_generated=chunk.is_generated) # TODO -- rewrite engine internals to make sure chunk.{generated,fast_forwarded}_tokens aren't empty... 
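The `is_generated` flag threaded through here decides what counts toward the user-visible token budget: sampled tokens do, forced or fast-forwarded text does not. A small sketch of that accounting rule (field names mirror `TextOutput`; the values are made up):

```python
from dataclasses import dataclass

@dataclass
class TextOutput:
    value: str
    token_count: int
    is_generated: bool  # False for forced / fast-forwarded text

def generated_tokens(outputs) -> int:
    return sum(o.token_count for o in outputs if o.is_generated)

outputs = [
    TextOutput("The answer is ", 4, is_generated=False),  # forced by the grammar
    TextOutput("42", 1, is_generated=True),               # sampled by the model
]
assert generated_tokens(outputs) == 1
```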
# # TODO: GenTokenExtra diff --git a/tests/model_integration/test_model.py b/tests/model_integration/test_model.py index affbe3c34..1f0430bac 100644 --- a/tests/model_integration/test_model.py +++ b/tests/model_integration/test_model.py @@ -26,7 +26,7 @@ def test_token_count(selected_model): lm = selected_model lm2 = lm + " 1 1 1 1 1" + gen(max_tokens=9) + gen(max_tokens=9) assert ( - 18 <= lm2.token_count <= 20 + 18 <= lm2.get_token_count() <= 20 ) # note we allow ourselves to be off by one because it is hard to know when we are continuing vs starting a new token in the parser From 07da88f73c19793375709abd582c9143dfe0113d Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 15:14:19 -0700 Subject: [PATCH 39/56] fix associativity test -- use str to run lm --- tests/model_integration/test_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/model_integration/test_model.py b/tests/model_integration/test_model.py index 1f0430bac..83582b180 100644 --- a/tests/model_integration/test_model.py +++ b/tests/model_integration/test_model.py @@ -76,11 +76,15 @@ def test_associativity(selected_model: models.Model): engine = selected_model.engine with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_1: - _ = selected_model + (prompt + grammar) + lm = selected_model + (prompt + grammar) + _ = str(lm) # trigger execution prompt_tokens_1 = get_logits_1.call_args_list[0].kwargs["token_ids"] with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_2: - _ = (selected_model + prompt) + grammar + lm = selected_model + prompt + _ = str(lm) # trigger execution + lm += grammar + _ = str(lm) # trigger execution prompt_tokens_2 = get_logits_2.call_args_list[0].kwargs["token_ids"] # Main assertion: the prompt tokens should be the same From 58b4817ec66245a921f4ba8c81dbad44277f0dbe Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 16:06:54 -0700 Subject: [PATCH 40/56] remove print --- guidance/models/_base/_interpreter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py index aa15d767f..45059da10 100644 --- a/guidance/models/_base/_interpreter.py +++ b/guidance/models/_base/_interpreter.py @@ -38,7 +38,6 @@ def __init__(self, state: S): async def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: async for attr in node.simplify()._run(self, **kwargs): if isinstance(attr, TextOutput) and attr.is_generated: - print(attr, attr.token_count) # TODO: this should probably be a lower-level responsibility? Not sure. 
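The wrapper this hunk touches is an async-generator "tee": it forwards every attr unchanged while updating state as a side effect. The same shape in isolation:

```python
import asyncio

async def source():
    for i in range(3):
        yield i

async def counting(stream, totals):
    async for item in stream:
        totals["seen"] = totals.get("seen", 0) + 1  # side effect (cf. token_count)
        yield item                                   # forwarded untouched

async def main():
    totals = {}
    async for item in counting(source(), totals):
        print(item)   # 0, 1, 2
    print(totals)     # {'seen': 3}

asyncio.run(main())
```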
self.state.token_count += attr.token_count yield attr From 4ad5a14b46f450ced8bbc5da6436903145540662 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 17:00:23 -0700 Subject: [PATCH 41/56] remove leftover generic --- guidance/_ast.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index 247609a2b..6476968ce 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -6,7 +6,6 @@ TYPE_CHECKING, Any, Callable, - Iterable, AsyncIterable, Optional, Sequence, @@ -195,11 +194,10 @@ async def __radd__(model): S = TypeVar("S", bound="State") -R = TypeVar("R", bound=Union[Iterable[OutputAttr], AsyncIterable[OutputAttr]]) class ASTNode(ABC): @abstractmethod - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: pass def simplify(self) -> "ASTNode": From d4fd6cb1c95fb5d9a667610860d655a8e83ed502 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 22 Apr 2025 17:02:28 -0700 Subject: [PATCH 42/56] remove leftover generic (missed some) --- guidance/_ast.py | 43 ++++++++++++++++++++++--------------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index 6476968ce..b8b0d7b4f 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -7,6 +7,7 @@ Any, Callable, AsyncIterable, + Iterator, Optional, Sequence, TypeVar, @@ -233,10 +234,10 @@ def null(cls) -> "ASTNode": class Concatenate(ASTNode): nodes: tuple[ASTNode, ...] - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.concatenate(self, **kwargs) - def __iter__(self) -> Iterable[ASTNode]: + def __iter__(self) -> Iterator[ASTNode]: for node in self.nodes: if isinstance(node, Concatenate): yield from node @@ -247,7 +248,7 @@ def __iter__(self) -> Iterable[ASTNode]: class RoleStart(ASTNode): role: str - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter._role_start(self, **kwargs) @@ -255,7 +256,7 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: class RoleEnd(ASTNode): role: str - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter._role_end(self, **kwargs) @@ -263,21 +264,21 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: class CaptureStart(ASTNode): name: str - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.capture_start(self, **kwargs) @dataclass class CaptureEnd(ASTNode): name: str - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.capture_end(self, **kwargs) @dataclass class ImageBlob(ASTNode): data: str - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.image_blob(self, **kwargs) @@ -285,7 +286,7 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: class ImageUrl(ASTNode): url: str - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def 
_run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.image_url(self, **kwargs) @@ -293,7 +294,7 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: class AudioBlob(ASTNode): data: str - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.audio_blob(self, **kwargs) @@ -301,7 +302,7 @@ class GenAudio(ASTNode): def __init__(self, kwargs: dict[str, Any]): self.kwargs = kwargs - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.gen_audio(self, **kwargs) @@ -415,7 +416,7 @@ class LiteralNode(GrammarNode): def is_null(self) -> bool: return self.value == "" - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.text(self, **kwargs) @@ -423,7 +424,7 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: class RegexNode(GrammarNode): regex: Optional[str] - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.regex(self, **kwargs) @@ -451,7 +452,7 @@ def simplify(self) -> "GrammarNode": def children(self) -> Sequence["GrammarNode"]: return self.alternatives - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.select(self, **kwargs) @@ -474,7 +475,7 @@ def simplify(self) -> "GrammarNode": def children(self) -> Sequence["GrammarNode"]: return self.nodes - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.join(self, **kwargs) @@ -500,7 +501,7 @@ def children(self) -> Sequence["GrammarNode"]: def simplify(self) -> GrammarNode: return RepeatNode(self.node.simplify(), self.min, self.max) - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.repeat(self, **kwargs) @@ -513,7 +514,7 @@ def is_terminal(self) -> bool: # this can be used as part of bigger regexes return True - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.substring(self, **kwargs) @@ -564,7 +565,7 @@ def is_terminal(self) -> bool: def children(self) -> Sequence["GrammarNode"]: return (self.value,) - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.rule(self, **kwargs) @@ -584,7 +585,7 @@ def is_terminal(self) -> bool: # so it should never be terminal. 
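The signature change above is mechanical, but the shape it preserves is worth naming: each node's `_run` is one half of a double dispatch, routing to the type-specific interpreter method. A toy version (hypothetical node and method names):

```python
from dataclasses import dataclass

class Interp:
    def text(self, node):
        return f"text:{node.value}"

    def select(self, node):
        return "select:" + "|".join(node.options)

@dataclass
class LiteralNode:
    value: str
    def _run(self, interp):
        return interp.text(self)

@dataclass
class SelectNode:
    options: tuple
    def _run(self, interp):
        return interp.select(self)

interp = Interp()
for node in (LiteralNode("hi"), SelectNode(("a", "b"))):
    print(node._run(interp))  # text:hi, then select:a|b
```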
return False - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: if self.target is None: raise ValueError("RuleRefNode target not set") return interpreter.rule(self.target) @@ -604,7 +605,7 @@ class SubgrammarNode(BaseSubgrammarNode): body: GrammarNode skip_regex: Optional[str] = None - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.subgrammar(self, **kwargs) @@ -612,7 +613,7 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: class JsonNode(BaseSubgrammarNode): schema: dict[str, Any] - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.json(self, **kwargs) @@ -620,7 +621,7 @@ def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: class LarkNode(BaseSubgrammarNode): lark_grammar: str - def _run(self, interpreter: "Interpreter[S, R]", **kwargs) -> R: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.lark(self, **kwargs) From b7650b86ff3234dfed4365fa199ff131f32794a8 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 24 Apr 2025 12:44:15 -0700 Subject: [PATCH 43/56] make ModelStream play nicely with lazy Model --- guidance/models/_base/_model.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index ee6c4d594..932f62005 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -362,32 +362,17 @@ class ModelStream: def __init__( self, model: Model, - grammar: Union["ModelStream", str, ASTNode, Function, None] = None, - timeout=5, + timeout: float = 5.0, ) -> None: """Create a model stream object that delays execution until it is iterated over.""" if model.echo: model = model.copy() model.echo = False # turn off display echoing self.model = model - self.grammar = grammar self.timeout = timeout - def __add__(self, grammar: Union[str, ASTNode]) -> Self: - """Extend this delayed chain of execution with another grammar append.""" - if self.grammar is None: - return ModelStream(self.model, grammar) - else: - return ModelStream(self.model, self.grammar + grammar) - - def _inner_run(self, model): - """This runs the model stream without iterating, and is only using internally by __iter__.""" - if isinstance(self.grammar, ModelStream): - model = self.grammar._inner_run(model) - elif self.grammar is None: - model = self.model + "" - else: - model = self.model + self.grammar + def __add__(self, other: Any) -> Self: + return ModelStream(self.model + other) def __iter__(self) -> Iterator[Model]: """Starts a thread to execute the model and grammar, yielding events as they occur.""" @@ -400,7 +385,7 @@ def __iter__(self) -> Iterator[Model]: def target(ctx): _event_queues.set(ctx[_event_queues]) try: - self._inner_run(self.model) + self.model._run_sync() events.put(None) # mark that we are done except BaseException as ex: events.put(ex) From 9e12ff7f0a0f5b8c5c1cb24168b98ea7d6233dc9 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 24 Apr 2025 12:51:17 -0700 Subject: [PATCH 44/56] call str to trigger execution for tests (todo: hide metrics behind accessor) --- tests/model_specific/test_transformers.py | 3 +++ 1 file changed, 3 insertions(+) 
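Same convention as the earlier test fixes: with the lazy model, `+=` only queues work, and engine metrics are populated once something forces execution. The pattern the diff below applies, shown in isolation (assumes local gpt2 weights are available, as in the test itself):

```python
from guidance import gen, models

lm = models.Transformers("gpt2")
lm += "1 + 1 = " + gen(max_tokens=2)
_ = str(lm)  # force the queued generation to actually run
assert lm.engine.metrics.engine_output_tokens > 0  # now meaningful
```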
diff --git a/tests/model_specific/test_transformers.py b/tests/model_specific/test_transformers.py index 07e332db2..f26e03254 100644 --- a/tests/model_specific/test_transformers.py +++ b/tests/model_specific/test_transformers.py @@ -41,14 +41,17 @@ def ff_prompt(lm): gpt2_noff = models.Transformers("gpt2", enable_backtrack=False, enable_ff_tokens=False) gpt2_noff += ff_prompt() + str(gpt2_noff) # Trigger execution noff_count = gpt2_noff.engine.metrics.engine_output_tokens gpt2_nobt = models.Transformers("gpt2", enable_backtrack=False) gpt2_nobt += ff_prompt() + str(gpt2_nobt) # Trigger execution nobt_count = gpt2_nobt.engine.metrics.engine_output_tokens gpt2_ff = models.Transformers("gpt2") gpt2_ff += ff_prompt() + str(gpt2_ff) # Trigger execution ff_count = gpt2_ff.engine.metrics.engine_output_tokens assert nobt_count == 2 From 439355f8e4cc5bbe49ae722bec0e51e1b19118ee Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 24 Apr 2025 15:09:43 -0700 Subject: [PATCH 45/56] Fix attribution to greenletio --- guidance/_reentrant_async.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/guidance/_reentrant_async.py b/guidance/_reentrant_async.py index 37536a147..48abbd2a4 100644 --- a/guidance/_reentrant_async.py +++ b/guidance/_reentrant_async.py @@ -1,6 +1,6 @@ """ -Heavily inspired by (read: largely stolen from) -https://github.com/miguelgrinberg/greenletio +Adapted from greenletio (https://github.com/miguelgrinberg/greenletio), used under the MIT License. +See LICENSE file or https://github.com/miguelgrinberg/greenletio/blob/main/LICENSE for details. """ import contextvars From f280ad916381e8bddfbb39a242c0b1574d06cd22 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Thu, 24 Apr 2025 15:10:14 -0700 Subject: [PATCH 46/56] Remove double import --- guidance/models/_base/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guidance/models/_base/__init__.py b/guidance/models/_base/__init__.py index 2808e0936..a8ce76d85 100644 --- a/guidance/models/_base/__init__.py +++ b/guidance/models/_base/__init__.py @@ -1,4 +1,4 @@ -from ._interpreter import Interpreter, Interpreter +from ._interpreter import Interpreter from ._model import Model from ._state import State From 6edf2763fb7f1cfb49079af91c478a1fc78c0603 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 10:46:10 -0700 Subject: [PATCH 47/56] refactor before merge --- guidance/models/_openai.py | 434 +------------------------------- guidance/models/_openai_base.py | 424 +++++++++++++++++++++++++++++++ 2 files changed, 430 insertions(+), 428 deletions(-) create mode 100644 guidance/models/_openai_base.py diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index 971f3de45..b53aa22c3 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -1,428 +1,6 @@ -import base64 -import wave -from io import BytesIO -from typing import TYPE_CHECKING, AsyncIterable, Literal, Optional, Union -from copy import deepcopy - -from pydantic import BaseModel, Discriminator, Field, TypeAdapter -from typing_extensions import Annotated, assert_never - -from .._ast import ( - ASTNode, - GenAudio, - ImageBlob, - ImageUrl, - JsonNode, - LiteralNode, - RegexNode, - RoleEnd, - RoleStart, - RuleNode, -) -from .._utils import bytes_from -from ..trace import ImageOutput, OutputAttr, TextOutput -from ..trace._trace import AudioOutput -from ._base import Interpreter, Model, State - -if TYPE_CHECKING: - from openai.types.chat import ChatCompletionChunk - - -def get_role_start(role: 
str) -> str: - # ChatML is as good as anything - return "<|im_start|>" + role + "\n" - - -def get_role_end(role: str) -> str: - # ChatML is as good as anything - return "\n<|im_end|>\n" - - -class AssistantAudio(BaseModel): - id: str - expires_at: int = Field(exclude=True) - data: str = Field(exclude=True) - transcript: str = Field(exclude=True) - - -class AssistantAudioMessage(BaseModel): - role: Literal["assistant"] - audio: AssistantAudio - - -class TextContent(BaseModel): - type: Literal["text"] - text: str - - -class InputAudio(BaseModel): - data: str - format: str - - -class AudioContent(BaseModel): - type: Literal["input_audio"] - input_audio: InputAudio - - -class ImageUrlContentInner(BaseModel): - url: str - - -class ImageUrlContent(BaseModel): - type: Literal["image_url"] - image_url: ImageUrlContentInner - - -Content = Annotated[Union[TextContent, AudioContent, ImageUrlContent], Discriminator("type")] - - -class ContentMessage(BaseModel): - role: Literal["system", "user", "assistant"] - content: list[Content] - - -Message = Union[ContentMessage, AssistantAudioMessage] - - -class OpenAIState(State): - def __init__(self) -> None: - super().__init__() - self.messages: list[Message] = [] - self.content: list[Content] = [] - self.audio: Optional[AssistantAudio] = None - - def apply_text(self, text: str) -> None: - if len(self.content) > 0 and isinstance(self.content[-1], TextContent): - self.content[-1].text += text - else: - self.content.append(TextContent(type="text", text=text)) - - def get_active_message(self) -> Optional[Message]: - if self.active_role is None: - return None - if self.content and self.audio: - raise ValueError("Cannot have both content and audio") - if self.audio: - return AssistantAudioMessage( - role=self.active_role, - audio=self.audio, - ) - elif self.content: - return ContentMessage( - role=self.active_role, - content=self.content, - ) - else: - return None - - def __str__(self) -> str: - messages = self.messages - active_message = self.get_active_message() - if active_message is not None: - messages = messages + [active_message] - s = "" - for i, message in enumerate(messages): - s += get_role_start(message.role) - if isinstance(message, AssistantAudioMessage): - s += "[AUDIO]" - elif isinstance(message, ContentMessage): - for content in message.content: - if isinstance(content, TextContent): - s += content.text - elif isinstance(content, ImageUrlContent): - s += "[IMAGE]" # Arbitrary stringification - elif isinstance(content, AudioContent): - s += "[AUDIO]" # transcript? - else: - if TYPE_CHECKING: - assert_never(content) - raise TypeError(f"Unknown content type: {content}") - else: - if TYPE_CHECKING: - assert_never(message) - raise TypeError(f"Unknown message type: {message}") - if active_message is None or i != len(messages) - 1: - # For the sake of consistency, don't add role end for the active message - s += get_role_end(message.role) - return s - - -class OpenAIInterpreter(Interpreter[OpenAIState]): - log_probs: bool = True - - def __init__( - self, - model: str, - api_key: Optional[str] = None, - **kwargs, - ): - try: - import openai - except ImportError: - raise Exception( - "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" 
- ) - self.state = OpenAIState() - self.model = model - self.client = openai.AsyncOpenAI(api_key=api_key, **kwargs) - - def __deepcopy__(self, memo): - """Custom deepcopy to ensure client is not copied.""" - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k == "client": - # Don't copy the client - setattr(result, k, v) - else: - setattr(result, k, deepcopy(v, memo)) - return result - - def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: - # if not isinstance(node, RoleStart) and self.state.active_role is None: - # raise ValueError( - # "OpenAI models require an active role (e.g. use `with assistant(): ...`)" - # ) - return super().run(node, **kwargs) - - async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: - self.state.active_role = node.role - # TODO: drop this and yield nothing. We need to add this for now as a workaround for the - # fact that current vis code assumes that there is actually a role start message - yield TextOutput(value=get_role_start(node.role), is_input=True) - - async def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: - self.state.messages.append(self.state.get_active_message()) - self.state.audio = None - self.state.content = [] - self.state.active_role = None - if False: - # I know this is weird, but this is how async generators work - yield - - async def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: - self.state.apply_text(node.value) - yield TextOutput(value=node.value, is_input=True) - - async def rule(self, node: RuleNode, **kwargs) -> AsyncIterable[OutputAttr]: - if node.stop: - raise ValueError("Stop condition not yet supported for OpenAI") - if node.suffix: - raise ValueError("Suffix not yet supported for OpenAI") - if node.stop_capture: - raise ValueError("Save stop text not yet supported for OpenAI") - - kwargs = kwargs.copy() - if node.temperature: - kwargs["temperature"] = node.temperature - if node.max_tokens: - kwargs["max_tokens"] = node.max_tokens - - chunks = self.run(node.value, **kwargs) - if node.capture: - buffered_text = "" - async for chunk in chunks: - # TODO: this isinstance check is pretty darn fragile. - # ~there must be a better way~ - if isinstance(chunk, TextOutput): - buffered_text += chunk.value - yield chunk - yield self.state.apply_capture( - name=node.capture, - value=buffered_text, - log_prob=1, # TODO - is_append=node.list_append, - ) - else: - async for chunk in chunks: - yield chunk - - def regex(self, node: RegexNode, **kwargs) -> AsyncIterable[OutputAttr]: - if node.regex is not None: - raise ValueError("Regex not yet supported for OpenAI") - # We're in unconstrained mode now. - return self._run(**kwargs) - - def json(self, node: JsonNode, **kwargs) -> AsyncIterable[OutputAttr]: - return self._run( - response_format={ - "type": "json_schema", - "json_schema": { - "name": "json_schema", # TODO? - "schema": node.schema, - "strict": True, - }, - }, - **kwargs, - ) - - async def _run(self, **kwargs) -> AsyncIterable[OutputAttr]: - if self.state.active_role is None: - # Should never happen? - raise ValueError( - "OpenAI models require chat blocks (e.g. use `with assistant(): ...`)" - ) - if self.state.active_role != "assistant": - raise ValueError( - "OpenAI models can only generate as the assistant (i.e. 
inside of `with assistant(): ...`)" - ) - if self.state.content: - raise ValueError( - f"OpenAI models do not support pre-filled assistant messages: got data {self.state.content}." - ) - - async with await self.client.chat.completions.create( - model=self.model, - messages=TypeAdapter(list[Message]).dump_python(self.state.messages), # type: ignore[arg-type] - logprobs=self.log_probs, - stream=True, - **kwargs, - ) as chunks: - async for output in self._handle_stream(chunks): - yield output - - async def _handle_stream( - self, chunks: AsyncIterable["ChatCompletionChunk"] - ) -> AsyncIterable[OutputAttr]: - audio: Optional[AssistantAudio] = None - async for chunk in chunks: - choice = chunk.choices[0] - delta = choice.delta - if delta.content is not None: - assert audio is None - content = delta.content - if len(content) == 0: - continue - self.state.apply_text(content) - if choice.logprobs is not None: - # TODO: actually get tokens from this and be less lazy - prob = 2.718 ** choice.logprobs.content[0].logprob - else: - prob = float("nan") - yield TextOutput(value=delta.content, is_generated=True, prob=prob) - elif getattr(delta, "audio", None) is not None: - transcript_chunk: Optional[str] = None - if audio is None: - assert delta.audio.get("id") is not None - audio = AssistantAudio( - id=delta.audio["id"], - expires_at=delta.audio.get("expires_at", 0), # ? - transcript=delta.audio.get("transcript", ""), - data=delta.audio.get("data", ""), - ) - transcript_chunk = delta.audio.get("transcript") - else: - assert delta.audio.get("id") is None or delta.audio["id"] == audio.id - if delta.audio.get("data") is not None: - audio.data += delta.audio["data"] - if delta.audio.get("transcript") is not None: - audio.transcript += delta.audio["transcript"] - transcript_chunk = delta.audio["transcript"] - if delta.audio.get("expires_at") is not None: - assert audio.expires_at == 0 - audio.expires_at = delta.audio["expires_at"] - if transcript_chunk is not None: - # Why not give the users some transcript? 
:) - yield TextOutput( - value=delta.audio["transcript"], - is_generated=True, - ) - elif delta.function_call is not None: - raise NotImplementedError("Function calling not yet supported for OpenAI") - elif delta.tool_calls is not None: - raise NotImplementedError("Tool calling not yet supported for OpenAI") - elif delta.refusal is not None: - raise ValueError(f"OpenAI refused the request: {delta.refusal}") - - if choice.finish_reason is not None: - break - - if audio is not None: - assert self.state.audio is None - self.state.audio = audio - # Create an in-memory WAV file - wav_buffer = BytesIO() - with wave.open(wav_buffer, "wb") as wav_file: - wav_file.setnchannels(1) - wav_file.setsampwidth(2) # PCM16 = 2 bytes per sample - wav_file.setframerate(22050) # A guess - wav_file.writeframes(base64.b64decode(audio.data)) - - # Get WAV bytes - wav_bytes = wav_buffer.getvalue() - yield AudioOutput(value=base64.b64encode(wav_bytes).decode(), is_input=False) - - def __deepcopy__(self, memo): - """Custom deepcopy to ensure client is not copied.""" - cls = self.__class__ - result = cls.__new__(cls) - memo[id(self)] = result - for k, v in self.__dict__.items(): - if k == "client": - # Don't copy the client - setattr(result, k, v) - else: - setattr(result, k, deepcopy(v, memo)) - return result - - -class OpenAIImageInterpreter(OpenAIInterpreter): - async def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: - try: - import PIL.Image - except ImportError: - raise Exception( - "Please install the Pillow package `pip install Pillow` in order to use images with OpenAI!" - ) - - image_bytes = base64.b64decode(node.data) - with PIL.Image.open(BytesIO(image_bytes)) as pil_image: - # Use PIL to infer file format - # TODO: just store format on ImageOutput type - format = pil_image.format - if format is None: - raise ValueError(f"Cannot upload image with unknown format") - - mime_type = f"image/{format.lower()}" - self.state.content.append( - {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{node.data}"}} - ) - yield ImageOutput(value=node.data, input=True) - - async def image_url(self, node: ImageUrl, **kwargs) -> AsyncIterable[OutputAttr]: - self.state.content.append({"type": "image_url", "image_url": {"url": node.url}}) - image_bytes = bytes_from(node.url, allow_local=False) - base64_string = base64.b64encode(image_bytes).decode("utf-8") - yield ImageOutput(value=base64_string, input=True) - - -class OpenAIAudioInterpreter(OpenAIInterpreter): - log_probs: bool = False - - async def audio_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: - format = "wav" # TODO: infer from node - self.state.content.append( - AudioContent( - type="input_audio", - input_audio=InputAudio( - data=node.data, - format=format, - ), - ) - ) - yield AudioOutput(value=node.data, format=format, input=True) - - async def gen_audio(self, node: GenAudio, **kwargs) -> AsyncIterable[OutputAttr]: - return self._run( - modalities=["text", "audio"], # Has to be both? 
- audio={ - "voice": node.kwargs.get("voice", "alloy"), - "format": "pcm16", # Has to be pcm16 for streaming - }, - ) - +from typing import Optional +from ._base import Model +from ._openai_base import BaseOpenAIInterpreter, OpenAIAudioMixin, OpenAIImageMixin class OpenAI(Model): def __init__( @@ -449,11 +27,11 @@ def __init__( """ if "audio-preview" in model: - interpreter_cls = OpenAIAudioInterpreter + interpreter_cls = type("OpenAIAudioInterpreter", (BaseOpenAIInterpreter, OpenAIAudioMixin), {}) elif model.startswith("gpt-4o") or model.startswith("o1"): - interpreter_cls = OpenAIImageInterpreter + interpreter_cls = type("OpenAIImageInterpreter", (BaseOpenAIInterpreter, OpenAIImageMixin), {}) else: - interpreter_cls = OpenAIInterpreter + interpreter_cls = BaseOpenAIInterpreter super().__init__( interpreter=interpreter_cls(model, api_key=api_key, **kwargs), echo=echo diff --git a/guidance/models/_openai_base.py b/guidance/models/_openai_base.py new file mode 100644 index 000000000..7b1e59ffe --- /dev/null +++ b/guidance/models/_openai_base.py @@ -0,0 +1,424 @@ +import base64 +import wave +from io import BytesIO +from typing import TYPE_CHECKING, AsyncIterable, Literal, Optional, Union +from copy import deepcopy + +from pydantic import BaseModel, Discriminator, Field, TypeAdapter +from typing_extensions import Annotated, assert_never + +from .._ast import ( + ASTNode, + GenAudio, + ImageBlob, + ImageUrl, + JsonNode, + LiteralNode, + RegexNode, + RoleEnd, + RoleStart, + RuleNode, +) +from .._utils import bytes_from +from ..trace import ImageOutput, OutputAttr, TextOutput +from ..trace._trace import AudioOutput +from ._base import Interpreter, State + +if TYPE_CHECKING: + from openai.types.chat import ChatCompletionChunk + + +def get_role_start(role: str) -> str: + # ChatML is as good as anything + return "<|im_start|>" + role + "\n" + + +def get_role_end(role: str) -> str: + # ChatML is as good as anything + return "\n<|im_end|>\n" + + +class AssistantAudio(BaseModel): + id: str + expires_at: int = Field(exclude=True) + data: str = Field(exclude=True) + transcript: str = Field(exclude=True) + + +class AssistantAudioMessage(BaseModel): + role: Literal["assistant"] + audio: AssistantAudio + + +class TextContent(BaseModel): + type: Literal["text"] + text: str + + +class InputAudio(BaseModel): + data: str + format: str + + +class AudioContent(BaseModel): + type: Literal["input_audio"] + input_audio: InputAudio + + +class ImageUrlContentInner(BaseModel): + url: str + + +class ImageUrlContent(BaseModel): + type: Literal["image_url"] + image_url: ImageUrlContentInner + + +Content = Annotated[Union[TextContent, AudioContent, ImageUrlContent], Discriminator("type")] + + +class ContentMessage(BaseModel): + role: Literal["system", "user", "assistant"] + content: list[Content] + + +Message = Union[ContentMessage, AssistantAudioMessage] + + +class OpenAIState(State): + def __init__(self) -> None: + super().__init__() + self.messages: list[Message] = [] + self.content: list[Content] = [] + self.audio: Optional[AssistantAudio] = None + + def apply_text(self, text: str) -> None: + if len(self.content) > 0 and isinstance(self.content[-1], TextContent): + self.content[-1].text += text + else: + self.content.append(TextContent(type="text", text=text)) + + def get_active_message(self) -> Optional[Message]: + if self.active_role is None: + return None + if self.content and self.audio: + raise ValueError("Cannot have both content and audio") + if self.audio: + return AssistantAudioMessage( + 
role=self.active_role, + audio=self.audio, + ) + elif self.content: + return ContentMessage( + role=self.active_role, + content=self.content, + ) + else: + return None + + def __str__(self) -> str: + messages = self.messages + active_message = self.get_active_message() + if active_message is not None: + messages = messages + [active_message] + s = "" + for i, message in enumerate(messages): + s += get_role_start(message.role) + if isinstance(message, AssistantAudioMessage): + s += "[AUDIO]" + elif isinstance(message, ContentMessage): + for content in message.content: + if isinstance(content, TextContent): + s += content.text + elif isinstance(content, ImageUrlContent): + s += "[IMAGE]" # Arbitrary stringification + elif isinstance(content, AudioContent): + s += "[AUDIO]" # transcript? + else: + if TYPE_CHECKING: + assert_never(content) + raise TypeError(f"Unknown content type: {content}") + else: + if TYPE_CHECKING: + assert_never(message) + raise TypeError(f"Unknown message type: {message}") + if active_message is None or i != len(messages) - 1: + # For the sake of consistency, don't add role end for the active message + s += get_role_end(message.role) + return s + + +class BaseOpenAIInterpreter(Interpreter[OpenAIState]): + log_probs: bool = True + + def __init__( + self, + model: str, + api_key: Optional[str] = None, + **kwargs, + ): + try: + import openai + except ImportError: + raise Exception( + "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" + ) + self.state = OpenAIState() + self.model = model + self.client = openai.AsyncOpenAI(api_key=api_key, **kwargs) + + def __deepcopy__(self, memo): + """Custom deepcopy to ensure client is not copied.""" + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "client": + # Don't copy the client + setattr(result, k, v) + else: + setattr(result, k, deepcopy(v, memo)) + return result + + def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: + # if not isinstance(node, RoleStart) and self.state.active_role is None: + # raise ValueError( + # "OpenAI models require an active role (e.g. use `with assistant(): ...`)" + # ) + return super().run(node, **kwargs) + + async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.active_role = node.role + # TODO: drop this and yield nothing. 
We need to add this for now as a workaround for the + # fact that current vis code assumes that there is actually a role start message + yield TextOutput(value=get_role_start(node.role), is_input=True) + + async def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.messages.append(self.state.get_active_message()) + self.state.audio = None + self.state.content = [] + self.state.active_role = None + if False: + # I know this is weird, but this is how async generators work + yield + + async def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.apply_text(node.value) + yield TextOutput(value=node.value, is_input=True) + + async def rule(self, node: RuleNode, **kwargs) -> AsyncIterable[OutputAttr]: + if node.stop: + raise ValueError("Stop condition not yet supported for OpenAI") + if node.suffix: + raise ValueError("Suffix not yet supported for OpenAI") + if node.stop_capture: + raise ValueError("Save stop text not yet supported for OpenAI") + + kwargs = kwargs.copy() + if node.temperature: + kwargs["temperature"] = node.temperature + if node.max_tokens: + kwargs["max_tokens"] = node.max_tokens + + chunks = self.run(node.value, **kwargs) + if node.capture: + buffered_text = "" + async for chunk in chunks: + # TODO: this isinstance check is pretty darn fragile. + # ~there must be a better way~ + if isinstance(chunk, TextOutput): + buffered_text += chunk.value + yield chunk + yield self.state.apply_capture( + name=node.capture, + value=buffered_text, + log_prob=1, # TODO + is_append=node.list_append, + ) + else: + async for chunk in chunks: + yield chunk + + def regex(self, node: RegexNode, **kwargs) -> AsyncIterable[OutputAttr]: + if node.regex is not None: + raise ValueError("Regex not yet supported for OpenAI") + # We're in unconstrained mode now. + return self._run(**kwargs) + + def json(self, node: JsonNode, **kwargs) -> AsyncIterable[OutputAttr]: + return self._run( + response_format={ + "type": "json_schema", + "json_schema": { + "name": "json_schema", # TODO? + "schema": node.schema, + "strict": True, + }, + }, + **kwargs, + ) + + async def _run(self, **kwargs) -> AsyncIterable[OutputAttr]: + if self.state.active_role is None: + # Should never happen? + raise ValueError( + "OpenAI models require chat blocks (e.g. use `with assistant(): ...`)" + ) + if self.state.active_role != "assistant": + raise ValueError( + "OpenAI models can only generate as the assistant (i.e. inside of `with assistant(): ...`)" + ) + if self.state.content: + raise ValueError( + f"OpenAI models do not support pre-filled assistant messages: got data {self.state.content}." 
+ ) + + async with await self.client.chat.completions.create( + model=self.model, + messages=TypeAdapter(list[Message]).dump_python(self.state.messages), # type: ignore[arg-type] + logprobs=self.log_probs, + stream=True, + **kwargs, + ) as chunks: + async for output in self._handle_stream(chunks): + yield output + + async def _handle_stream( + self, chunks: AsyncIterable["ChatCompletionChunk"] + ) -> AsyncIterable[OutputAttr]: + audio: Optional[AssistantAudio] = None + async for chunk in chunks: + choice = chunk.choices[0] + delta = choice.delta + if delta.content is not None: + assert audio is None + content = delta.content + if len(content) == 0: + continue + self.state.apply_text(content) + if choice.logprobs is not None: + # TODO: actually get tokens from this and be less lazy + prob = 2.718 ** choice.logprobs.content[0].logprob + else: + prob = float("nan") + yield TextOutput(value=delta.content, is_generated=True, prob=prob) + elif getattr(delta, "audio", None) is not None: + transcript_chunk: Optional[str] = None + if audio is None: + assert delta.audio.get("id") is not None + audio = AssistantAudio( + id=delta.audio["id"], + expires_at=delta.audio.get("expires_at", 0), # ? + transcript=delta.audio.get("transcript", ""), + data=delta.audio.get("data", ""), + ) + transcript_chunk = delta.audio.get("transcript") + else: + assert delta.audio.get("id") is None or delta.audio["id"] == audio.id + if delta.audio.get("data") is not None: + audio.data += delta.audio["data"] + if delta.audio.get("transcript") is not None: + audio.transcript += delta.audio["transcript"] + transcript_chunk = delta.audio["transcript"] + if delta.audio.get("expires_at") is not None: + assert audio.expires_at == 0 + audio.expires_at = delta.audio["expires_at"] + if transcript_chunk is not None: + # Why not give the users some transcript? :) + yield TextOutput( + value=delta.audio["transcript"], + is_generated=True, + ) + elif delta.function_call is not None: + raise NotImplementedError("Function calling not yet supported for OpenAI") + elif delta.tool_calls is not None: + raise NotImplementedError("Tool calling not yet supported for OpenAI") + elif delta.refusal is not None: + raise ValueError(f"OpenAI refused the request: {delta.refusal}") + + if choice.finish_reason is not None: + break + + if audio is not None: + assert self.state.audio is None + self.state.audio = audio + # Create an in-memory WAV file + wav_buffer = BytesIO() + with wave.open(wav_buffer, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) # PCM16 = 2 bytes per sample + wav_file.setframerate(22050) # A guess + wav_file.writeframes(base64.b64decode(audio.data)) + + # Get WAV bytes + wav_bytes = wav_buffer.getvalue() + yield AudioOutput(value=base64.b64encode(wav_bytes).decode(), is_input=False) + + def __deepcopy__(self, memo): + """Custom deepcopy to ensure client is not copied.""" + cls = self.__class__ + result = cls.__new__(cls) + memo[id(self)] = result + for k, v in self.__dict__.items(): + if k == "client": + # Don't copy the client + setattr(result, k, v) + else: + setattr(result, k, deepcopy(v, memo)) + return result + + +class OpenAIImageMixin(BaseOpenAIInterpreter): + async def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: + try: + import PIL.Image + except ImportError: + raise Exception( + "Please install the Pillow package `pip install Pillow` in order to use images with OpenAI!" 
+ ) + + image_bytes = base64.b64decode(node.data) + with PIL.Image.open(BytesIO(image_bytes)) as pil_image: + # Use PIL to infer file format + # TODO: just store format on ImageOutput type + format = pil_image.format + if format is None: + raise ValueError(f"Cannot upload image with unknown format") + + mime_type = f"image/{format.lower()}" + self.state.content.append( + {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{node.data}"}} + ) + yield ImageOutput(value=node.data, input=True) + + async def image_url(self, node: ImageUrl, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.content.append({"type": "image_url", "image_url": {"url": node.url}}) + image_bytes = bytes_from(node.url, allow_local=False) + base64_string = base64.b64encode(image_bytes).decode("utf-8") + yield ImageOutput(value=base64_string, input=True) + + +class OpenAIAudioMixin(BaseOpenAIInterpreter): + log_probs: bool = False + + async def audio_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: + format = "wav" # TODO: infer from node + self.state.content.append( + AudioContent( + type="input_audio", + input_audio=InputAudio( + data=node.data, + format=format, + ), + ) + ) + yield AudioOutput(value=node.data, format=format, input=True) + + async def gen_audio(self, node: GenAudio, **kwargs) -> AsyncIterable[OutputAttr]: + return self._run( + modalities=["text", "audio"], # Has to be both? + audio={ + "voice": node.kwargs.get("voice", "alloy"), + "format": "pcm16", # Has to be pcm16 for streaming + }, + ) From d4c8b0efdf70ecce1be9b014a43d0c4a60a68e0a Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 11:16:50 -0700 Subject: [PATCH 48/56] async openai --- guidance/models/_openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index ec802373f..ed00af9ae 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -32,7 +32,7 @@ def __init__( raise Exception( "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" ) - client = openai.OpenAI(api_key, **kwargs) + client = openai.AsyncOpenAI(api_key, **kwargs) super().__init__(model=model, client=client) From 4ff3d3d33eae3a7a40ef1284a660ef52688945d9 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 12:12:38 -0700 Subject: [PATCH 49/56] clean up openai a bit --- guidance/models/_openai.py | 16 ++-------------- guidance/models/_openai_base.py | 2 +- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index ed00af9ae..7fd81b856 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -1,20 +1,8 @@ -from typing import TYPE_CHECKING, Callable, Iterator, Optional, Union +from typing import Optional -from pydantic import TypeAdapter - - -from .._ast import ( - JsonNode, - RuleNode, -) -from ..trace import OutputAttr from ._base import Model - from ._openai_base import ( BaseOpenAIInterpreter, - AudioContent, - OpenAIState, - Message, OpenAIImageMixin, OpenAIAudioMixin, ) @@ -32,7 +20,7 @@ def __init__( raise Exception( "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" 
) - client = openai.AsyncOpenAI(api_key, **kwargs) + client = openai.AsyncOpenAI(api_key=api_key, **kwargs) super().__init__(model=model, client=client) diff --git a/guidance/models/_openai_base.py b/guidance/models/_openai_base.py index 0eb4d6bef..eb37508d6 100644 --- a/guidance/models/_openai_base.py +++ b/guidance/models/_openai_base.py @@ -157,7 +157,7 @@ class BaseOpenAIInterpreter(Interpreter[OpenAIState]): def __init__( self, model: str, - client: "openai.OpenAI", + client: "openai.AsyncOpenAI", ): try: import openai From 4125e01b4f660219e2adf36d90ab16ffbeed6dc1 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 12:13:27 -0700 Subject: [PATCH 50/56] fix openai with Concatenate ast --- guidance/models/_openai_base.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/guidance/models/_openai_base.py b/guidance/models/_openai_base.py index eb37508d6..8759c36d6 100644 --- a/guidance/models/_openai_base.py +++ b/guidance/models/_openai_base.py @@ -169,13 +169,6 @@ def __init__( self.model = model self.client = client - def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: - if not isinstance(node, RoleStart) and self.state.active_role is None: - raise ValueError( - "OpenAI models require an active role (e.g. use `with assistant(): ...`)" - ) - return super().run(node, **kwargs) - async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = node.role # TODO: drop this and yield nothing. We need to add this for now as a workaround for the From 5a649ed6aad4c98f5ea60d6aa866d852ae86ee64 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 12:14:24 -0700 Subject: [PATCH 51/56] bring vllm up to speed --- guidance/models/experimental/__init__.py | 2 +- guidance/models/experimental/_vllm.py | 83 +++++++++++++++--------- 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/guidance/models/experimental/__init__.py b/guidance/models/experimental/__init__.py index 09d1311b5..5e0180708 100644 --- a/guidance/models/experimental/__init__.py +++ b/guidance/models/experimental/__init__.py @@ -1 +1 @@ -from ._vllm import VLLMModel \ No newline at end of file +from ._vllm import VLLMModel diff --git a/guidance/models/experimental/_vllm.py b/guidance/models/experimental/_vllm.py index 96cc26202..30e300745 100644 --- a/guidance/models/experimental/_vllm.py +++ b/guidance/models/experimental/_vllm.py @@ -1,10 +1,12 @@ -from typing import Iterator, Optional, TYPE_CHECKING +from typing import AsyncIterable, Optional, TYPE_CHECKING import wave import base64 from io import BytesIO from copy import deepcopy from pydantic import TypeAdapter + if TYPE_CHECKING: + import openai from openai.types.chat import ChatCompletionChunk from ..._ast import GrammarNode, RoleStart, RoleEnd, ASTNode, LiteralNode @@ -13,15 +15,14 @@ from .._openai_base import OpenAIState, AssistantAudio, Message, get_role_start from .._base import Model, Interpreter + class BaseOpenAIInterpreterForVLLM(Interpreter[OpenAIState]): log_probs: bool = True def __init__( self, model: str, - base_url: Optional[str] = None, - api_key: Optional[str] = None, - **kwargs, + client: "openai.AsyncOpenAI", ): try: import openai @@ -31,33 +32,28 @@ def __init__( ) self.state = OpenAIState() self.model = model - self.client = openai.OpenAI(base_url=base_url, api_key=api_key, **kwargs) + self.client = client - def role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + async def role_start(self, node: RoleStart, **kwargs) -> 
AsyncIterable[OutputAttr]: self.state.active_role = node.role # TODO: drop this and yield nothing. We need to add this for now as a workaround for the # fact that current vis code assumes that there is actually a role start message yield TextOutput(value=get_role_start(node.role), is_input=True) - def role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + async def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: self.state.messages.append(self.state.get_active_message()) self.state.audio = None self.state.content = [] self.state.active_role = None - yield from () + if False: + # I know this is weird, but this is how async generators work + yield - def text(self, node: LiteralNode, **kwargs) -> Iterator[OutputAttr]: + async def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: self.state.apply_text(node.value) yield TextOutput(value=node.value, is_input=True) - def run(self, node: ASTNode, **kwargs) -> Iterator[OutputAttr]: - if not isinstance(node, RoleStart) and self.state.active_role is None: - raise ValueError( - "OpenAI models require an active role (e.g. use `with assistant(): ...`)" - ) - return super().run(node, **kwargs) - - def _run(self, **kwargs) -> Iterator[OutputAttr]: + async def _run(self, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is None: # Should never happen? raise ValueError( @@ -72,21 +68,27 @@ def _run(self, **kwargs) -> Iterator[OutputAttr]: f"OpenAI models do not support pre-filled assistant messages: got data {self.state.content}." ) - with self.client.chat.completions.create( + async with await self.client.chat.completions.create( model=self.model, messages=TypeAdapter(list[Message]).dump_python(self.state.messages), # type: ignore[arg-type] logprobs=self.log_probs, stream=True, **kwargs, ) as chunks: - yield from self._handle_stream(chunks) + async for output in self._handle_stream(chunks): + yield output - def _handle_stream( - self, chunks: Iterator["ChatCompletionChunk"] - ) -> Iterator[OutputAttr]: + async def _handle_stream( + self, chunks: AsyncIterable["ChatCompletionChunk"] + ) -> AsyncIterable[OutputAttr]: audio: Optional[AssistantAudio] = None - for chunk in chunks: - choice = chunk.choices[0] + async for chunk in chunks: + try: + choice = chunk.choices[0] + except IndexError: + # TODO: azure seems to return empty choices sometimes (on first chunk?) 
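+                # (skipping these should be lossless: an empty `choices` list
+                # carries no delta to apply)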
+                # Need to make this more robust
+                continue
             delta = choice.delta
             if delta.content is not None:
                 assert audio is None
@@ -94,7 +96,7 @@ def _handle_stream(
                 if len(content) == 0:
                     continue
                 self.state.apply_text(content)
-                if choice.logprobs is not None:
+                if getattr(choice, "logprobs", None) is not None:
                     # TODO: actually get tokens from this and be less lazy
                     prob = 2.718 ** choice.logprobs.content[0].logprob
                 else:
                     prob = float("nan")
@@ -165,11 +167,12 @@ def __deepcopy__(self, memo):
                 setattr(result, k, deepcopy(v, memo))
         return result
 
+
 class VLLMInterpreter(BaseOpenAIInterpreterForVLLM):
-    def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]:
+    async def grammar(self, node: GrammarNode, **kwargs) -> AsyncIterable[OutputAttr]:
         buffer: str = ""
-        for attr in self._run(
-            extra_body = dict(
+        async for attr in self._run(
+            extra_body=dict(
                 guided_decoding_backend="guidance",
                 guided_grammar=node.ll_grammar(),
             )
@@ -190,13 +193,29 @@ def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]:
                 assert isinstance(log_probs, list)
                 assert len(value) == len(log_probs)
                 for v, l in zip(value, log_probs):
-                    yield self.state.apply_capture(name=name, value=v, log_prob=l, is_append=True)
+                    yield self.state.apply_capture(
+                        name=name, value=v, log_prob=l, is_append=True
+                    )
             else:
-                yield self.state.apply_capture(name=name, value=value, log_prob=log_probs, is_append=False)
+                yield self.state.apply_capture(
+                    name=name, value=value, log_prob=log_probs, is_append=False
+                )
+
 
 class VLLMModel(Model):
-    def __init__(self, model: str, echo=True, **kwargs):
+    def __init__(
+        self, model: str, base_url: str, api_key: Optional[str] = None, echo: bool = True, **kwargs
+    ):
+        try:
+            import openai
+        except ImportError:
+            raise Exception(
+                "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.experimental.VLLMModel!"
+ ) super().__init__( - interpreter=VLLMInterpreter(model=model, **kwargs), + interpreter=VLLMInterpreter( + model=model, + client=openai.AsyncOpenAI(base_url=base_url, api_key=api_key, **kwargs), + ), echo=echo, ) From 1650092aa569e26d650421428022fd759fc1cb2c Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 15:19:30 -0700 Subject: [PATCH 52/56] regain my sanity -- refactor Model __init__ with dataclass --- guidance/models/_base/_model.py | 81 ++++++++++++++++----------------- 1 file changed, 39 insertions(+), 42 deletions(-) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 932f62005..fdfe2a539 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -6,37 +6,23 @@ from contextvars import ContextVar, copy_context from copy import deepcopy from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union, Sequence -import warnings -from ..._reentrant_async import sync_to_reentrant_async, reentrant_await, run_async_coroutine_in_bg_async - from typing_extensions import Self from ..._ast import ( ASTNode, - Concatenate, Function, AsyncFunction, - GenAudio, - GrammarNode, - ImageBlob, - ImageUrl, - LiteralNode, - RoleEnd, - RoleStart, CaptureStart, CaptureEnd, _parse_tags, ) from ...trace import ( - ImageInput, - LiteralInput, NodeAttr, - RoleCloserInput, - RoleOpenerInput, - StatelessGuidanceInput, TraceNode, ) -from ...trace._trace import AudioInput +from ...visual import TraceMessage +from ..._reentrant_async import sync_to_reentrant_async, reentrant_await, run_async_coroutine_in_bg_async + from ._interpreter import Interpreter from ._state import State @@ -63,22 +49,45 @@ def _gen_id(): S = TypeVar("S", bound=State) D = TypeVar("D", bound=Any) +from dataclasses import dataclass, field, InitVar + +@dataclass class Model: - def __init__( - self, - interpreter: Interpreter[S], - echo: bool = True, - ) -> None: - self.echo = echo + interpreter: InitVar[Interpreter[S]] + echo: bool = True + + # Private init attributes + _interpreter: Interpreter = field(init=False) + _parent: Optional["Model"] = None + _pending: Union[None, ASTNode, Function] = None + _active_blocks: tuple["Block", ...] = () + + # Private non-init attributes + _parent_id: Optional[int] = field(init=False, default=None) + _id: int = field(init=False, default_factory=_gen_id) + _trace_nodes: set[TraceNode] = field(init=False, default_factory=set) + + def __post_init__(self, interpreter: Interpreter) -> None: self._interpreter = interpreter - self._pending: Union[None, ASTNode, Function] = None - self._active_blocks: tuple[Block, ...] = () + # Set the parent ID if we have a parent + if self._parent is not None: + self._parent_id = self._parent._id - self._parent: Optional["Model"] = None - self._parent_id: Optional[int] = None - self._id: int = _gen_id() - self._trace_nodes: set[TraceNode] = set() - self._update_trace_node(self._id, self._parent_id, None) + def copy(self) -> Self: + obj = object.__new__(self.__class__) + obj.__dict__.update(self.__dict__) + # Use the base-class's __init__ to set up the new object + # TODO: if we can move to having just the one Model class, + # we can replace this all with a simple `dataclasses.replace(self, ...)` + Model.__init__( + obj, + interpreter=deepcopy(self._interpreter), + # TODO: should this be our parent? Or is the copy really our child? 
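+            # (as written, the copy is the child: its _parent points back at
+            # self, which keeps the trace tree connected)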
+ _parent=self, + _pending=self._pending, + _active_blocks=self._active_blocks, + ) + return obj def _update_trace_node( self, identifier: int, parent_id: Optional[int], node_attr: Optional[NodeAttr] = None @@ -155,18 +164,6 @@ def _apply_blocks(self) -> None: self._add_to_pending(opener) self._active_blocks = tuple(new_active_blocks) - def copy(self) -> Self: - obj = object.__new__(self.__class__) - obj.__dict__.update(self.__dict__) - - obj._interpreter = deepcopy(self._interpreter) - obj._id = _gen_id() - obj._parent_id = self._id - obj._trace_nodes = set() - obj._parent = self - obj._update_trace_node(obj._id, obj._parent_id, None) - return obj - def __str__(self) -> str: return str(self._get_state()) From 3a29a83c2f2ac05871aa3ca3eecd4a15ba43c4d2 Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 15:29:47 -0700 Subject: [PATCH 53/56] restore vis support --- guidance/models/_base/_model.py | 35 ++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index fdfe2a539..07f288024 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -92,21 +92,27 @@ def copy(self) -> Self: def _update_trace_node( self, identifier: int, parent_id: Optional[int], node_attr: Optional[NodeAttr] = None ) -> None: - # from ...registry import get_trace_handler, get_renderer - - # trace_handler = get_trace_handler() - # trace_node = trace_handler.update_node(identifier, parent_id, node_attr) - # self._trace_nodes.add(trace_node) - # if self.echo: - # get_renderer().update( - # TraceMessage( - # trace_id=identifier, - # parent_trace_id=parent_id, - # node_attr=node_attr, - # ), - # ) + from ...registry import get_trace_handler, get_renderer + + trace_handler = get_trace_handler() + trace_node = trace_handler.update_node(identifier, parent_id, node_attr) + self._trace_nodes.add(trace_node) + if self.echo: + get_renderer().update( + TraceMessage( + trace_id=identifier, + parent_trace_id=parent_id, + node_attr=node_attr, + ), + ) pass + def _increment_trace_id(self) -> None: + # This is a bit of a hack to get the trace ids working (only one output attr is allowed per id, so we need to increment.) + # Parent will be the real parent, so this is all a bit of a mess. 
TODO: allow multiple output attrs per id + self._parent_id = self._id + self._id = _gen_id() + def _add_to_pending(self, item: Union[ASTNode, Function]) -> None: if self._pending is None: self._pending = item @@ -271,6 +277,7 @@ def run_batched(self, items: Sequence[Union[str, Function, AsyncFunction, ASTNod return reentrant_await(self.run_batched_async(items)) async def _run(self) -> None: + # TODO: trace `InputAttr`s async def inner(): new_self = self.copy() # may be some pending blocks @@ -314,10 +321,10 @@ def _run_sync(self) -> None: async def _run_node(self, node: ASTNode) -> None: async for output_attr in self._interpreter.run(node): + self._increment_trace_id() self._update_trace_node(self._id, self._parent_id, output_attr) # Stream current model state self._send_to_event_queue() - return self async def _get_state_async(self) -> State: """Get the state of the model.""" From 4bd5c2032e663d7515dcaa5fb73885c60a30800f Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 15:50:08 -0700 Subject: [PATCH 54/56] generalize interpreter run to yield InputAttr in addition to OutputAttr --- guidance/_ast.py | 39 +++++++++++++++++++++------ guidance/models/_base/_interpreter.py | 26 +++--------------- guidance/models/_base/_model.py | 4 +-- 3 files changed, 37 insertions(+), 32 deletions(-) diff --git a/guidance/_ast.py b/guidance/_ast.py index c9c63c5e3..4aec11322 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -18,7 +18,9 @@ from typing_extensions import assert_never from ._parser import ByteParser, ByteParserException -from .trace import OutputAttr +from .trace import InputAttr, OutputAttr, RoleOpenerInput, RoleCloserInput + +NodeAttr = Union[InputAttr, OutputAttr] if TYPE_CHECKING: from .models._base import Interpreter, State @@ -199,7 +201,7 @@ async def __radd__(model): class ASTNode(ABC): @abstractmethod - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]: pass def simplify(self) -> "ASTNode": @@ -235,8 +237,25 @@ def null(cls) -> "ASTNode": class Concatenate(ASTNode): nodes: tuple[ASTNode, ...] 
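    # (the async _run below coalesces adjacent GrammarNode children into a single
    # grammar before dispatch, so non-grammar nodes are the only split points)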
- def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: - return interpreter.concatenate(self, **kwargs) + async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]: + buffer: Optional[GrammarNode] = None + for child in self: + assert not isinstance(child, Concatenate) # iter should be flat + if isinstance(child, GrammarNode): + if buffer is None: + buffer = child + else: + buffer = buffer + child + else: + if buffer is not None: + async for attr in interpreter.run(buffer, **kwargs): + yield attr + buffer = None + async for attr in interpreter.run(child, **kwargs): + yield attr + if buffer is not None: + async for attr in interpreter.run(buffer, **kwargs): + yield attr def __iter__(self) -> Iterator[ASTNode]: for node in self.nodes: @@ -249,16 +268,20 @@ def __iter__(self) -> Iterator[ASTNode]: class RoleStart(ASTNode): role: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: - return interpreter._role_start(self, **kwargs) + async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]: + yield RoleOpenerInput(name=self.role) + async for output_attr in interpreter._role_start(self, **kwargs): + yield output_attr @dataclass class RoleEnd(ASTNode): role: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: - return interpreter._role_end(self, **kwargs) + async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]: + yield RoleCloserInput(name=self.role) + async for output_attr in interpreter._role_end(self, **kwargs): + yield output_attr @dataclass diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py index 45059da10..8c1c0612c 100644 --- a/guidance/models/_base/_interpreter.py +++ b/guidance/models/_base/_interpreter.py @@ -26,16 +26,18 @@ CaptureEnd, ) from ..._utils import bytes_from -from ...trace import OutputAttr, TextOutput +from ...trace import InputAttr, OutputAttr, TextOutput from ._state import State +NodeAttr = Union[InputAttr, OutputAttr] + S = TypeVar("S", bound=State) class Interpreter(Generic[S], ABC): def __init__(self, state: S): self.state = state - async def run(self, node: ASTNode, **kwargs) -> AsyncIterable[OutputAttr]: + async def run(self, node: ASTNode, **kwargs) -> AsyncIterable[NodeAttr]: async for attr in node.simplify()._run(self, **kwargs): if isinstance(attr, TextOutput) and attr.is_generated: # TODO: this should probably be a lower-level responsibility? Not sure. 
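With role handling folded into the AST nodes themselves, a consumer of Interpreter.run now sees input events and generated output interleaved on one stream. A rough sketch, not wired to a real backend (node constructors follow the dataclasses in this series):

    from guidance._ast import Concatenate, LiteralNode, RoleEnd, RoleStart

    async def dump(interpreter) -> None:
        node = Concatenate((RoleStart("user"), LiteralNode("hi"), RoleEnd("user")))
        async for attr in interpreter.run(node):
            # InputAttr (RoleOpenerInput/RoleCloserInput) and OutputAttr
            # (e.g. TextOutput) arrive in document order on the same stream.
            print(type(attr).__name__, getattr(attr, "value", None))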
@@ -51,26 +53,6 @@ async def capture_start(self, node: CaptureStart, **kwargs) -> AsyncIterable[Out async def capture_end(self, node: CaptureEnd, **kwargs) -> AsyncIterable[OutputAttr]: yield self.state.close_capture(node.name) - async def concatenate(self, node: Concatenate, **kwargs) -> AsyncIterable[OutputAttr]: - buffer: Optional[GrammarNode] = None - for child in node: - assert not isinstance(child, Concatenate) # iter should be flat - if isinstance(child, GrammarNode): - if buffer is None: - buffer = child - else: - buffer = buffer + child - else: - if buffer is not None: - async for attr in self.run(buffer, **kwargs): - yield attr - buffer = None - async for attr in self.run(child, **kwargs): - yield attr - if buffer is not None: - async for attr in self.run(buffer, **kwargs): - yield attr - def _role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is not None: raise ValueError( diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 07f288024..2e97457af 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -320,9 +320,9 @@ def _run_sync(self) -> None: return reentrant_await(self._run()) async def _run_node(self, node: ASTNode) -> None: - async for output_attr in self._interpreter.run(node): + async for node_attr in self._interpreter.run(node): self._increment_trace_id() - self._update_trace_node(self._id, self._parent_id, output_attr) + self._update_trace_node(self._id, self._parent_id, node_attr) # Stream current model state self._send_to_event_queue() From 14af49c53a590b387357671e60eb66b3751e15da Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 15:50:27 -0700 Subject: [PATCH 55/56] fix mixin order with openai --- guidance/models/_openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index 7fd81b856..70cc4aefc 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -51,11 +51,11 @@ def __init__( if "audio-preview" in model: interpreter_cls = type( - "OpenAIAudioInterpreter", (OpenAIInterpreter, OpenAIAudioMixin), {} + "OpenAIAudioInterpreter", (OpenAIAudioMixin, OpenAIInterpreter), {} ) elif model.startswith("gpt-4o") or model.startswith("o1"): interpreter_cls = type( - "OpenAIImageInterpreter", (OpenAIInterpreter, OpenAIImageMixin), {} + "OpenAIImageInterpreter", (OpenAIImageMixin, OpenAIInterpreter), {} ) else: interpreter_cls = BaseOpenAIInterpreter From 62950b4144a8d3c616fe7c5a68c2a3f6873c67ee Mon Sep 17 00:00:00 2001 From: Hudson Cooper Date: Tue, 6 May 2025 15:52:34 -0700 Subject: [PATCH 56/56] fix openai audio gen --- guidance/models/_openai_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/guidance/models/_openai_base.py b/guidance/models/_openai_base.py index 8759c36d6..27cd14126 100644 --- a/guidance/models/_openai_base.py +++ b/guidance/models/_openai_base.py @@ -401,7 +401,7 @@ async def audio_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAtt ) yield AudioOutput(value=node.data, format=format, input=True) - async def gen_audio(self, node: GenAudio, **kwargs) -> AsyncIterable[OutputAttr]: + def gen_audio(self, node: GenAudio, **kwargs) -> AsyncIterable[OutputAttr]: return self._run( modalities=["text", "audio"], # Has to be both? audio={