5 changes: 4 additions & 1 deletion src/pdl/pdl.py

@@ -70,7 +70,10 @@ def exec_program(
     if not isinstance(scope, PdlDict):
         scope = PdlDict(scope or {})
     loc = loc or empty_block_location
-    future_result, _, future_scope, trace = process_prog(state, scope, prog, loc)
+    initial_scope = {"pdl_model_default_parameters": get_default_model_parameters()}
+    future_result, _, future_scope, trace = process_prog(
+        state, scope | initial_scope, prog, loc
+    )
     result = future_result.result()
     match output:
         case "result":
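The hunk above seeds the interpreter scope with model defaults before `process_prog` runs. Assuming `PdlDict` keeps standard Python `dict` union semantics, the right-hand operand of `|` wins on key collisions, so a `pdl_model_default_parameters` entry already present in the caller's scope would be replaced by the freshly computed defaults. A minimal sketch with plain dicts and made-up parameter values:

# Minimal sketch of the union in `scope | initial_scope`
# (assumes PdlDict behaves like a plain dict; values are illustrative).
scope = {"I": 0, "pdl_model_default_parameters": [{"temperature": 1.0}]}
initial_scope = {"pdl_model_default_parameters": [{"temperature": 0.7}]}

merged = scope | initial_scope  # right-hand side wins on duplicate keys
assert merged == {"I": 0, "pdl_model_default_parameters": [{"temperature": 0.7}]}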
39 changes: 26 additions & 13 deletions src/pdl/pdl_interpreter.py

@@ -1278,15 +1278,6 @@ def process_call_model(
                 concrete_block, "parameters", scope, loc
             )
 
-            # Apply PDL defaults to model invocation
-            if concrete_block.parameters is None or isinstance(
-                concrete_block.parameters, dict
-            ):
-                concrete_block.parameters = apply_defaults(
-                    str(model_id),
-                    concrete_block.parameters or {},
-                    scope.get("pdl_model_default_parameters", []),
-                )
         case GraniteioModelBlock():
             _, concrete_block = process_expr_of(concrete_block, "backend", scope, loc)
             if concrete_block.processor is not None:
@@ -1339,7 +1330,9 @@ def get_transformed_inputs(kwargs):
         if getenv("OTEL_EXPORTER") and getenv("OTEL_ENDPOINT"):
             litellm.callbacks = ["otel"]
 
-        msg, raw_result = generate_client_response(state, concrete_block, model_input)
+        msg, raw_result = generate_client_response(
+            state, scope, concrete_block, str(model_id), model_input
+        )
         background: LazyMessages = PdlList([lazy_apply(lambda msg: msg | {"defsite": block.pdl__id}, msg)])  # type: ignore
         result = lazy_apply(
             lambda msg: "" if msg["content"] is None else msg["content"], msg
@@ -1368,17 +1361,19 @@ def get_transformed_inputs(kwargs):
 
 def generate_client_response(
     state: InterpreterState,
+    scope: ScopeType,
     block: LitellmModelBlock | GraniteioModelBlock,
+    model_id: str,
     model_input: ModelInput,
 ) -> tuple[LazyMessage, PdlLazy[Any]]:
     match state.batch:
         case 0:
             model_output, raw_result = generate_client_response_streaming(
-                state, block, model_input
+                state, scope, block, model_id, model_input
             )
         case 1:
             model_output, raw_result = generate_client_response_single(
-                state, block, model_input
+                state, scope, block, model_id, model_input
             )
         case _:
             assert False
@@ -1387,7 +1382,9 @@ def generate_client_response(
 
 def generate_client_response_streaming(
     state: InterpreterState,
+    scope: ScopeType,
     block: LitellmModelBlock | GraniteioModelBlock,
+    model_id: str,
     model_input: ModelInput,
 ) -> tuple[LazyMessage, PdlLazy[Any]]:
     msg_stream: Generator[dict[str, Any], Any, Any]
@@ -1400,6 +1397,13 @@
             assert parameters is None or isinstance(
                 parameters, dict
             )  # block is a "concrete block"
+            # Apply PDL defaults to model invocation
+
+            parameters = apply_defaults(
+                model_id,
+                parameters or {},
+                scope.get("pdl_model_default_parameters", []),
+            )
             msg_stream = LitellmModel.generate_text_stream(
                 model_id=value_of_expr(block.model),
                 messages=model_input,
@@ -1408,7 +1412,9 @@
             )
         case GraniteioModelBlock():
             # TODO: curently fallback to the non-streaming interface
-            return generate_client_response_single(state, block, model_input)
+            return generate_client_response_single(
+                state, scope, block, model_id, model_input
+            )
         case _:
             assert False
     complete_msg: Optional[dict[str, Any]] = None
@@ -1465,7 +1471,9 @@ def litellm_parameters_to_dict(
 
 def generate_client_response_single(
     state: InterpreterState,
+    scope: ScopeType,
     block: LitellmModelBlock | GraniteioModelBlock,
+    model_id: str,
     model_input: ModelInput,
 ) -> tuple[LazyMessage, PdlLazy[Any]]:
     if block.parameters is None:
@@ -1475,6 +1483,11 @@
     assert parameters is None or isinstance(
         parameters, dict
     )  # block is a "concrete block"
+    parameters = apply_defaults(
+        model_id,
+        parameters or {},
+        scope.get("pdl_model_default_parameters", []),
+    )
     block.pdl__usage = PdlUsage()
     match block:
         case LitellmModelBlock():
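With this refactor, the defaults merge happens inside both generate_client_response_streaming and generate_client_response_single (and the Granite.io streaming path falls back to the single path), so every invocation route applies defaults exactly once, instead of only process_call_model doing it for LiteLLM blocks. The diff shows only the call shape of apply_defaults(model_id, parameters, defaults); below is a hypothetical sketch of a helper with that shape. The data layout (a list of pattern-to-parameters dicts) and the substring matching rule are assumptions not confirmed by this PR — this is not PDL's actual apply_defaults.

from typing import Any

# Hypothetical sketch only -- NOT PDL's actual apply_defaults.
# Assumed layout: defaults is a list of {model_pattern: params} dicts;
# assumed rule: a pattern applies when it is a substring of model_id.
def apply_defaults_sketch(
    model_id: str,
    parameters: dict[str, Any],
    defaults: list[dict[str, dict[str, Any]]],
) -> dict[str, Any]:
    merged: dict[str, Any] = {}
    for entry in defaults:
        for pattern, params in entry.items():
            if pattern in model_id:
                merged.update(params)
    merged.update(parameters)  # explicit block parameters win over defaults
    return merged

defaults = [{"granite": {"temperature": 0.0, "max_tokens": 1024}}]
print(apply_defaults_sketch("watsonx/granite-13b", {"temperature": 0.7}, defaults))
# -> {'temperature': 0.7, 'max_tokens': 1024}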
8 changes: 4 additions & 4 deletions tests/test_var.py

@@ -112,10 +112,10 @@ def test_code_var():
     result = exec_dict(code_var_data, output="all")
     text = result["result"]
     scope = result["scope"]
-    assert scope == {
-        "pdl_context": [{"role": "user", "content": text, "defsite": "text.0.code"}],
-        "I": 0,
-    }
+    assert scope["pdl_context"] == [
+        {"role": "user", "content": text, "defsite": "text.0.code"}
+    ]
+    assert scope["I"] == 0
     assert text == "0"
 
 
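The whole-scope equality assertion is relaxed here because exec_program now seeds the initial scope with pdl_model_default_parameters, so the resulting scope carries an extra key the old expected dict did not contain. A small illustration (values are made up):

# Why exact-equality over the whole scope no longer holds (illustrative values).
expected = {"pdl_context": [], "I": 0}
actual = {**expected, "pdl_model_default_parameters": []}  # key injected by exec_program
assert actual != expected              # exact-equality check would now fail
assert actual["I"] == expected["I"]    # per-key assertions stay stable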