Skip to content

Commit 4f7ad39

Browse files
authored
Bug fixes for setting default parameters (#838)
* Bug fixes for default parameter setting Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 0c68f79 commit 4f7ad39

File tree

3 files changed

+34
-18
lines changed

3 files changed

+34
-18
lines changed

src/pdl/pdl.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ def exec_program(
7070
if not isinstance(scope, PdlDict):
7171
scope = PdlDict(scope or {})
7272
loc = loc or empty_block_location
73-
future_result, _, future_scope, trace = process_prog(state, scope, prog, loc)
73+
initial_scope = {"pdl_model_default_parameters": get_default_model_parameters()}
74+
future_result, _, future_scope, trace = process_prog(
75+
state, scope | initial_scope, prog, loc
76+
)
7477
result = future_result.result()
7578
match output:
7679
case "result":

src/pdl/pdl_interpreter.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,15 +1278,6 @@ def process_call_model(
12781278
concrete_block, "parameters", scope, loc
12791279
)
12801280

1281-
# Apply PDL defaults to model invocation
1282-
if concrete_block.parameters is None or isinstance(
1283-
concrete_block.parameters, dict
1284-
):
1285-
concrete_block.parameters = apply_defaults(
1286-
str(model_id),
1287-
concrete_block.parameters or {},
1288-
scope.get("pdl_model_default_parameters", []),
1289-
)
12901281
case GraniteioModelBlock():
12911282
_, concrete_block = process_expr_of(concrete_block, "backend", scope, loc)
12921283
if concrete_block.processor is not None:
@@ -1339,7 +1330,9 @@ def get_transformed_inputs(kwargs):
13391330
if getenv("OTEL_EXPORTER") and getenv("OTEL_ENDPOINT"):
13401331
litellm.callbacks = ["otel"]
13411332

1342-
msg, raw_result = generate_client_response(state, concrete_block, model_input)
1333+
msg, raw_result = generate_client_response(
1334+
state, scope, concrete_block, str(model_id), model_input
1335+
)
13431336
background: LazyMessages = PdlList([lazy_apply(lambda msg: msg | {"defsite": block.pdl__id}, msg)]) # type: ignore
13441337
result = lazy_apply(
13451338
lambda msg: "" if msg["content"] is None else msg["content"], msg
@@ -1368,17 +1361,19 @@ def get_transformed_inputs(kwargs):
13681361

13691362
def generate_client_response(
13701363
state: InterpreterState,
1364+
scope: ScopeType,
13711365
block: LitellmModelBlock | GraniteioModelBlock,
1366+
model_id: str,
13721367
model_input: ModelInput,
13731368
) -> tuple[LazyMessage, PdlLazy[Any]]:
13741369
match state.batch:
13751370
case 0:
13761371
model_output, raw_result = generate_client_response_streaming(
1377-
state, block, model_input
1372+
state, scope, block, model_id, model_input
13781373
)
13791374
case 1:
13801375
model_output, raw_result = generate_client_response_single(
1381-
state, block, model_input
1376+
state, scope, block, model_id, model_input
13821377
)
13831378
case _:
13841379
assert False
@@ -1387,7 +1382,9 @@ def generate_client_response(
13871382

13881383
def generate_client_response_streaming(
13891384
state: InterpreterState,
1385+
scope: ScopeType,
13901386
block: LitellmModelBlock | GraniteioModelBlock,
1387+
model_id: str,
13911388
model_input: ModelInput,
13921389
) -> tuple[LazyMessage, PdlLazy[Any]]:
13931390
msg_stream: Generator[dict[str, Any], Any, Any]
@@ -1400,6 +1397,13 @@ def generate_client_response_streaming(
14001397
assert parameters is None or isinstance(
14011398
parameters, dict
14021399
) # block is a "concrete block"
1400+
# Apply PDL defaults to model invocation
1401+
1402+
parameters = apply_defaults(
1403+
model_id,
1404+
parameters or {},
1405+
scope.get("pdl_model_default_parameters", []),
1406+
)
14031407
msg_stream = LitellmModel.generate_text_stream(
14041408
model_id=value_of_expr(block.model),
14051409
messages=model_input,
@@ -1408,7 +1412,9 @@ def generate_client_response_streaming(
14081412
)
14091413
case GraniteioModelBlock():
14101414
# TODO: currently fallback to the non-streaming interface
1411-
return generate_client_response_single(state, block, model_input)
1415+
return generate_client_response_single(
1416+
state, scope, block, model_id, model_input
1417+
)
14121418
case _:
14131419
assert False
14141420
complete_msg: Optional[dict[str, Any]] = None
@@ -1465,7 +1471,9 @@ def litellm_parameters_to_dict(
14651471

14661472
def generate_client_response_single(
14671473
state: InterpreterState,
1474+
scope: ScopeType,
14681475
block: LitellmModelBlock | GraniteioModelBlock,
1476+
model_id: str,
14691477
model_input: ModelInput,
14701478
) -> tuple[LazyMessage, PdlLazy[Any]]:
14711479
if block.parameters is None:
@@ -1475,6 +1483,11 @@ def generate_client_response_single(
14751483
assert parameters is None or isinstance(
14761484
parameters, dict
14771485
) # block is a "concrete block"
1486+
parameters = apply_defaults(
1487+
model_id,
1488+
parameters or {},
1489+
scope.get("pdl_model_default_parameters", []),
1490+
)
14781491
block.pdl__usage = PdlUsage()
14791492
match block:
14801493
case LitellmModelBlock():

tests/test_var.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ def test_code_var():
112112
result = exec_dict(code_var_data, output="all")
113113
text = result["result"]
114114
scope = result["scope"]
115-
assert scope == {
116-
"pdl_context": [{"role": "user", "content": text, "defsite": "text.0.code"}],
117-
"I": 0,
118-
}
115+
assert scope["pdl_context"] == [
116+
{"role": "user", "content": text, "defsite": "text.0.code"}
117+
]
118+
assert scope["I"] == 0
119119
assert text == "0"
120120

121121

0 commit comments

Comments
 (0)