1717warnings .filterwarnings ("ignore" , "Valid config keys have changed in V2" )
1818
1919from pathlib import Path # noqa: E402
20- from typing import Any , Generator , Optional , Sequence , TypeVar # noqa: E402
20+ from typing import Any , Generator , Generic , Optional , Sequence , TypeVar # noqa: E402
2121
2222import httpx # noqa: E402
2323import json_repair # noqa: E402
3333from jinja2 .nodes import TemplateData # noqa: E402
3434from jinja2 .runtime import Undefined # noqa: E402
3535from pydantic import BaseModel , ConfigDict , Field # noqa: E402
36+ from pydantic .json_schema import SkipJsonSchema # noqa: E402
3637
3738from .pdl_ast import ( # noqa: E402
3839 AdvancedBlockType ,
131132empty_scope : ScopeType = PdlDict ({"pdl_context" : DependentContext ([])})
132133
133134
135+ RefT = TypeVar ("RefT" )
136+
137+
138+ class Ref (Generic [RefT ]):
139+ def __init__ (self , ref : RefT ):
140+ self .ref = ref
141+
142+
134143class InterpreterState (BaseModel ):
135144 model_config = ConfigDict (arbitrary_types_allowed = True )
136145
137146 yield_result : bool = False
147+ """Stream the result on the standard output as soon as possible."""
138148 yield_background : bool = False
149+ """Stream the toplevel pdl_context on the standard output as soon as possible."""
139150 batch : int = 1
140- # batch=0: streaming
141- # batch=1: call to generate with `input`
151+ """
152+ Stream the output of the LLM
153+ - batch=0: streaming
154+ - batch=1: call to generate with `input`
155+ """
142156 role : RoleType = "user"
157+ """Current role to add messages in the context."""
143158 cwd : Path = Path .cwd ()
144- # background_tasks = {}
159+ """Current working directory."""
145160 id_stack : list [str ] = []
161+ """Id generator for the UI."""
162+
163+ # The following are shared variable that should be modified by side effects
146164 event_loop : AbstractEventLoop = Field (default_factory = create_event_loop_thread )
165+ """Event loop to schedule LLM calls."""
166+ current_pdl_context : Ref [LazyMessages ] = Ref (DependentContext ([]))
167+ """Current value of the context set at the beginning of the execution of the block."""
147168
148169 def with_yield_result (self : "InterpreterState" , b : bool ) -> "InterpreterState" :
149170 return self .model_copy (update = {"yield_result" : b })
@@ -168,6 +189,19 @@ def with_pop(self: "InterpreterState") -> "InterpreterState":
168189 return self .model_copy (update = {"id_stack" : stack })
169190
170191
192+ class ClosureBlock (FunctionBlock ):
193+ pdl__scope : SkipJsonSchema [Optional [ScopeType ]] = Field (repr = False )
194+ pdl__state : SkipJsonSchema [InterpreterState ] = Field (repr = False )
195+
196+ def __call__ (self , ** kwds ):
197+ state = self .pdl__state .with_yield_result (False ).with_yield_background (False )
198+ current_context = state .current_pdl_context .ref
199+ result , _ , _ = execute_call (
200+ state , current_context , self , kwds , empty_block_location
201+ )
202+ return result
203+
204+
171205def generate (
172206 pdl_file : str | Path ,
173207 state : Optional [InterpreterState ],
@@ -246,6 +280,7 @@ def process_block(
246280 background : LazyMessages
247281 trace : BlockType
248282 try :
283+ state .current_pdl_context .ref = scope ["pdl_context" ] # type: ignore
249284 if not isinstance (block , Block ):
250285 start = time .time_ns ()
251286 try :
@@ -436,7 +471,7 @@ def process_advanced_block(
436471 result .result ()
437472 break
438473 except Exception as exc :
439- err_msg = exc . args [ 0 ]
474+ err_msg = traceback . format_exc ()
440475 do_retry = (
441476 block .retry
442477 and trial_idx + 1 < trial_total
@@ -915,7 +950,23 @@ def process_block_body(
915950 result , background , scope , trace = process_import (state , scope , block , loc )
916951
917952 case FunctionBlock ():
918- closure = block .model_copy ()
953+ closure = ClosureBlock ( # pyright: ignore
954+ description = block .description ,
955+ spec = block .spec ,
956+ defs = block .defs ,
957+ def_ = block .def_ , # pyright: ignore
958+ contribute = block .contribute ,
959+ parser = block .parser ,
960+ fallback = block .fallback ,
961+ retry = block .retry ,
962+ trace_error_on_retry = block .trace_error_on_retry ,
963+ role = block .role ,
964+ function = block .function ,
965+ return_ = block .return_ , # pyright: ignore
966+ pdl__location = loc ,
967+ pdl__scope = None ,
968+ pdl__state = state ,
969+ )
919970 if block .def_ is not None :
920971 scope = scope | {block .def_ : closure }
921972 closure .pdl__scope = scope
@@ -1872,7 +1923,7 @@ def process_call(
18721923 background : LazyMessages = DependentContext ([])
18731924 args , block = process_expr_of (block , "args" , scope , loc )
18741925 closure , _ = process_expr_of (block , "call" , scope , loc )
1875- if not isinstance (closure , FunctionBlock ):
1926+ if not isinstance (closure , ClosureBlock ):
18761927 msg = f"Type error: { block .call } is of type { type (closure )} but should be a function."
18771928 if isinstance (closure , str ) and isinstance (scope .get (closure ), FunctionBlock ):
18781929 msg += " You might want to call `${ " + str (block .call ) + " }`."
@@ -1890,12 +1941,28 @@ def process_call(
18901941 loc = args_loc ,
18911942 trace = block .model_copy (),
18921943 )
1944+ current_context = scope .data ["pdl_context" ]
1945+ try :
1946+ result , background , call_trace = execute_call (
1947+ state , current_context , closure , args , loc
1948+ )
1949+ except PDLRuntimeError as exc :
1950+ raise PDLRuntimeError (
1951+ exc .message ,
1952+ loc = exc .loc or closure .pdl__location ,
1953+ trace = block .model_copy (update = {"pdl__trace" : exc .pdl__trace }),
1954+ ) from exc
1955+ trace = block .model_copy (update = {"pdl__trace" : call_trace })
1956+ return result , background , scope , trace
1957+
1958+
1959+ def execute_call (state , current_context , closure , args , loc ):
18931960 if "pdl_context" in args :
1894- args [ "pdl_context" ] = deserialize (args ["pdl_context" ])
1961+ args = args | { "pdl_context" : deserialize (args ["pdl_context" ])}
18951962 f_body = closure .return_
18961963 f_scope = (
18971964 (closure .pdl__scope or PdlDict ({}))
1898- | PdlDict ({"pdl_context" : scope . data [ "pdl_context" ] })
1965+ | PdlDict ({"pdl_context" : current_context })
18991966 | PdlDict ((args or {}))
19001967 )
19011968 if closure .pdl__location is not None :
@@ -1906,27 +1973,19 @@ def process_call(
19061973 )
19071974 else :
19081975 fun_loc = empty_block_location
1909- try :
1910- result , background , _ , f_trace = process_block (state , f_scope , f_body , fun_loc )
1911- except PDLRuntimeError as exc :
1912- raise PDLRuntimeError (
1913- exc .message ,
1914- loc = exc .loc or fun_loc ,
1915- trace = block .model_copy (update = {"pdl__trace" : exc .pdl__trace }),
1916- ) from exc
1917- trace = block .model_copy (update = {"pdl__trace" : f_trace })
1976+ result , background , _ , f_trace = process_block (state , f_scope , f_body , fun_loc )
19181977 if closure .spec is not None :
19191978 result = lazy_apply (
19201979 lambda r : result_with_type_checking (
19211980 r ,
19221981 closure .spec ,
1923- f"Type errors in result of function call to { block . call } :" ,
1924- loc ,
1925- trace ,
1982+ f"Type errors in result of the function{ ' ' + closure . signature . get ( 'name' , '' ) if closure . signature is not None else '' } :" ,
1983+ fun_loc ,
1984+ f_trace ,
19261985 ),
19271986 result ,
19281987 )
1929- return result , background , scope , trace
1988+ return result , background , f_trace
19301989
19311990
19321991def process_input (
0 commit comments