Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion guidance/library/_capture.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .._guidance import guidance
from .._grammar import capture as grammar_capture, GrammarFunction

# Adapted from active_role_end in _model.py, functionality should be shared probably?
import re
format_pattern = re.compile(r"<\|\|_.*?_\|\|>", flags=re.DOTALL)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slightly concerned that this appears to be relying on ChatML tags, which not all models use


@guidance(stateless=lambda *args, **kwargs: isinstance(args[0], GrammarFunction))
def capture(lm, value, name):
Expand All @@ -9,4 +12,11 @@ def capture(lm, value, name):
else:
start_len = len(lm)
lm += value
return lm.set(name, str(lm)[start_len:])
# Adapted from active_role_end in _model.py
parts = ""
for _, role_end_str in lm.opened_blocks.values():
role_end_str = format_pattern.sub("", role_end_str)
if len(role_end_str) > 0 and not re.fullmatch(r"\s+", role_end_str):
parts += role_end_str

return lm.set(name, str(lm)[start_len-len(parts):].removesuffix(parts))
21 changes: 15 additions & 6 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

AVAILABLE_MODELS = {
"gpt2cpu": dict(name="transformers:gpt2", kwargs=dict()),
"phi2cpu": dict(
name="transformers:microsoft/phi-2", kwargs={"trust_remote_code": True}
),
"phi2cpu": dict(name="transformers:microsoft/phi-2", kwargs={"trust_remote_code": True}),
"azure_guidance": dict(
name="azure_guidance:",
kwargs={},
Expand All @@ -41,9 +39,7 @@
name="huggingface_hubllama:TheBloke/Llama-2-7B-GGUF:llama-2-7b.Q5_K_M.gguf",
kwargs={"verbose": True, "n_ctx": 4096},
),
"transformers_mistral_7b": dict(
name="transformers:mistralai/Mistral-7B-v0.1", kwargs=dict()
),
"transformers_mistral_7b": dict(name="transformers:mistralai/Mistral-7B-v0.1", kwargs=dict()),
"hfllama_mistral_7b": dict(
name="huggingface_hubllama:TheBloke/Mistral-7B-Instruct-v0.2-GGUF:mistral-7b-instruct-v0.2.Q8_0.gguf",
kwargs={"verbose": True},
Expand Down Expand Up @@ -101,6 +97,19 @@ def selected_model(selected_model_info: str) -> models.Model:
return model


@pytest.fixture(scope="module")
def model_with_role_tags(selected_model, selected_model_name):
    """Provide the session's selected model, but only if it supports chat role tags.

    Tests that exercise role blocks (user/assistant/system) need a chat-tuned
    model; for any other selected model the dependent test is skipped rather
    than failed.
    """
    # Models known to understand role tags (chat-format models).
    role_tag_models = {
        "transformers_phi3cpu_mini_4k_instruct",
        "transformers_llama3cpu_8b",
        "hfllama_phi3cpu_mini_4k_instruct",
        "hfllama_mistral_7b",
    }
    if selected_model_name not in role_tag_models:
        pytest.skip("Requires a model that supports role tags!")
    return selected_model


@pytest.fixture(scope="function")
def rate_limiter() -> int:
"""Limit test execution rate
Expand Down
8 changes: 8 additions & 0 deletions tests/model_integration/library/test_capture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import guidance

def test_capture_within_role(model_with_role_tags: guidance.models.Model):
    """Capturing text inside a role block must yield exactly the captured text.

    Regression check that role-end formatting (e.g. closing role tags) does not
    leak into the captured variable.
    """
    expected = "This is some text in a role."
    model = model_with_role_tags
    with guidance.user():
        model = model + guidance.capture(expected, "test")
    assert model["test"] == expected
17 changes: 12 additions & 5 deletions tests/unit/library/test_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,18 @@ def raw_fn(lm):
elif lm["state"] == "2":
lm += select(["5", "6"], name="state_2")
return lm

lm_nocap = lm + "the beginning|" + raw_fn() + "|the end"
lm_cap_arg = lm + "the beginning|" + capture("<cap>" + raw_fn() + "</cap>" , "cap_arg") + "|the end"
lm_cap_kwarg = lm + "the beginning|" + capture("<cap>" + raw_fn() + "</cap>", name="cap_kwarg") + "|the end"

lm_cap_arg = (
lm + "the beginning|" + capture("<cap>" + raw_fn() + "</cap>", "cap_arg") + "|the end"
)
lm_cap_kwarg = (
lm
+ "the beginning|"
+ capture("<cap>" + raw_fn() + "</cap>", name="cap_kwarg")
+ "|the end"
)

# Bunch of random tests
assert "state_1" in lm_nocap or "state_2" in lm_nocap
assert "cap_arg" in lm_cap_arg
Expand All @@ -42,4 +49,4 @@ def raw_fn(lm):

assert str(lm_nocap).endswith("|the end")
assert str(lm_cap_arg).endswith("|the end")
assert str(lm_cap_kwarg).endswith("|the end")
assert str(lm_cap_kwarg).endswith("|the end")