25 changes: 0 additions & 25 deletions src/guardrails/_openai_utils.py

This file was deleted.

5 changes: 2 additions & 3 deletions src/guardrails/agents.py
@@ -18,7 +18,6 @@
 from pathlib import Path
 from typing import Any

-from ._openai_utils import prepare_openai_kwargs
 from .utils.conversation import merge_conversation_with_items, normalize_conversation

 logger = logging.getLogger(__name__)
@@ -167,7 +166,7 @@ def _create_default_tool_context() -> Any:
     class DefaultContext:
         guardrail_llm: AsyncOpenAI

-    return DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    return DefaultContext(guardrail_llm=AsyncOpenAI())


 def _create_conversation_context(
@@ -393,7 +392,7 @@ def _create_agents_guardrails_from_config(
     class DefaultContext:
         guardrail_llm: AsyncOpenAI

-    context = DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    context = DefaultContext(guardrail_llm=AsyncOpenAI())

     def _create_stage_guardrail(stage_name: str):
         async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:

24 changes: 16 additions & 8 deletions src/guardrails/checks/text/llm_base.py
@@ -48,6 +48,8 @@ class MyLLMOutput(LLMOutput):
 from guardrails.types import CheckFn, GuardrailLLMContextProto, GuardrailResult
 from guardrails.utils.output import OutputSchema

+from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier
+
 if TYPE_CHECKING:
     from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore[unused-import]
 else:
@@ -62,10 +64,10 @@ class MyLLMOutput(LLMOutput):

 __all__ = [
     "LLMConfig",
-    "LLMOutput",
     "LLMErrorOutput",
-    "create_llm_check_fn",
+    "LLMOutput",
     "create_error_result",
+    "create_llm_check_fn",
 ]
Comment on lines 65 to 71

Copilot AI (Oct 30, 2025):

[nitpick] The __all__ list was reordered alphabetically, but this appears to be an incidental change unrelated to the PR's main purpose of refactoring the safety identifier. The previous ordering grouped related items together ("LLMConfig", "LLMOutput", "LLMErrorOutput", then the functions). Consider reverting this change to keep the PR focused on its core objective.

@@ -247,12 +249,18 @@ async def _request_chat_completion(
     response_format: dict[str, Any],
 ) -> Any:
     """Invoke chat.completions.create on sync or async OpenAI clients."""
-    return await _invoke_openai_callable(
-        client.chat.completions.create,
-        messages=messages,
-        model=model,
-        response_format=response_format,
-    )
+    # Only include safety_identifier for official OpenAI API
+    kwargs: dict[str, Any] = {
+        "messages": messages,
+        "model": model,
+        "response_format": response_format,
+    }
+
+    # Only official OpenAI API supports safety_identifier (not Azure or local models)
+    if supports_safety_identifier(client):
+        kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
+    return await _invoke_openai_callable(client.chat.completions.create, **kwargs)


 async def run_llm(

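Note: the helper module these hunks import is not shown in this PR view. Below is a minimal sketch of what src/guardrails/utils/safety_identifier.py presumably provides, assuming it gates on the Azure client classes and the configured base URL; the identifier constant here is a placeholder, not the project's real value.

from openai import AsyncAzureOpenAI, AzureOpenAI

# Placeholder value; the real constant lives in the module omitted from this diff.
SAFETY_IDENTIFIER = "guardrails-placeholder"


def supports_safety_identifier(client: object) -> bool:
    """Return True only for clients that target the official OpenAI API."""
    # Azure clients do not accept the safety_identifier parameter.
    if isinstance(client, (AzureOpenAI, AsyncAzureOpenAI)):
        return False
    # A custom base_url usually means a local or third-party OpenAI-compatible server.
    base_url = str(getattr(client, "base_url", "") or "")
    return "api.openai.com" in base_url
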
4 changes: 1 addition & 3 deletions src/guardrails/checks/text/moderation.py
@@ -39,8 +39,6 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailResult

-from ..._openai_utils import prepare_openai_kwargs
-
 logger = logging.getLogger(__name__)

 __all__ = ["moderation", "Category", "ModerationCfg"]
@@ -129,7 +127,7 @@ def _get_moderation_client() -> AsyncOpenAI:
     Returns:
         AsyncOpenAI: Cached OpenAI API client for moderation checks.
     """
-    return AsyncOpenAI(**prepare_openai_kwargs({}))
+    return AsyncOpenAI()


 async def moderation(

9 changes: 2 additions & 7 deletions src/guardrails/client.py
@@ -26,7 +26,6 @@
     GuardrailsResponse,
     OpenAIResponseType,
 )
-from ._openai_utils import prepare_openai_kwargs
 from ._streaming import StreamingMixin
 from .exceptions import GuardrailTripwireTriggered
 from .runtime import run_guardrails
@@ -167,7 +166,6 @@ def __init__(
                 by this parameter.
             **openai_kwargs: Additional arguments passed to AsyncOpenAI constructor.
         """
-        openai_kwargs = prepare_openai_kwargs(openai_kwargs)
         # Initialize OpenAI client first
         super().__init__(**openai_kwargs)

@@ -205,7 +203,7 @@ class DefaultContext:
         default_headers = getattr(self, "default_headers", None)
         if default_headers is not None:
             guardrail_kwargs["default_headers"] = default_headers
-        guardrail_client = AsyncOpenAI(**prepare_openai_kwargs(guardrail_kwargs))
+        guardrail_client = AsyncOpenAI(**guardrail_kwargs)

         return DefaultContext(guardrail_llm=guardrail_client)

@@ -335,7 +333,6 @@ def __init__(
                 by this parameter.
             **openai_kwargs: Additional arguments passed to OpenAI constructor.
         """
-        openai_kwargs = prepare_openai_kwargs(openai_kwargs)
         # Initialize OpenAI client first
         super().__init__(**openai_kwargs)

@@ -373,7 +370,7 @@ class DefaultContext:
         default_headers = getattr(self, "default_headers", None)
         if default_headers is not None:
             guardrail_kwargs["default_headers"] = default_headers
-        guardrail_client = OpenAI(**prepare_openai_kwargs(guardrail_kwargs))
+        guardrail_client = OpenAI(**guardrail_kwargs)

         return DefaultContext(guardrail_llm=guardrail_client)

@@ -516,7 +513,6 @@ def __init__(
                 by this parameter.
             **azure_kwargs: Additional arguments passed to AsyncAzureOpenAI constructor.
         """
-        azure_kwargs = prepare_openai_kwargs(azure_kwargs)
         # Initialize Azure client first
         super().__init__(**azure_kwargs)

@@ -671,7 +667,6 @@ def __init__(
                 by this parameter.
             **azure_kwargs: Additional arguments passed to AzureOpenAI constructor.
         """
-        azure_kwargs = prepare_openai_kwargs(azure_kwargs)
         super().__init__(**azure_kwargs)

         # Store the error handling preference

5 changes: 2 additions & 3 deletions src/guardrails/evals/guardrail_evals.py
@@ -23,7 +23,6 @@


 from guardrails import instantiate_guardrails, load_pipeline_bundles
-from guardrails._openai_utils import prepare_openai_kwargs
 from guardrails.evals.core import (
     AsyncRunEngine,
     BenchmarkMetricsCalculator,
@@ -281,7 +280,7 @@ def _create_context(self) -> Context:
             if self.api_key:
                 azure_kwargs["api_key"] = self.api_key

-            guardrail_llm = AsyncAzureOpenAI(**prepare_openai_kwargs(azure_kwargs))
+            guardrail_llm = AsyncAzureOpenAI(**azure_kwargs)
             logger.info("Created Azure OpenAI client for endpoint: %s", self.azure_endpoint)
         # OpenAI or OpenAI-compatible API
         else:
@@ -292,7 +291,7 @@ def _create_context(self) -> Context:
                 openai_kwargs["base_url"] = self.base_url
                 logger.info("Created OpenAI-compatible client for base_url: %s", self.base_url)

-            guardrail_llm = AsyncOpenAI(**prepare_openai_kwargs(openai_kwargs))
+            guardrail_llm = AsyncOpenAI(**openai_kwargs)

         return Context(guardrail_llm=guardrail_llm)

31 changes: 22 additions & 9 deletions src/guardrails/resources/chat/chat.py
@@ -6,6 +6,7 @@
 from typing import Any

 from ..._base_client import GuardrailsBaseClient
+from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier


 class Chat:
@@ -82,12 +83,19 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals

         # Run input guardrails and LLM call concurrently using a thread for the LLM
         with ThreadPoolExecutor(max_workers=1) as executor:
+            # Only include safety_identifier for OpenAI clients (not Azure)
+            llm_kwargs = {
+                "messages": modified_messages,
+                "model": model,
+                "stream": stream,
+                **kwargs,
+            }
+            if supports_safety_identifier(self._client._resource_client):
+                llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
             llm_future = executor.submit(
                 self._client._resource_client.chat.completions.create,
-                messages=modified_messages,  # Use messages with any preflight modifications
-                model=model,
-                stream=stream,
-                **kwargs,
+                **llm_kwargs,
             )
             input_results = self._client._run_stage_guardrails(
                 "input",
@@ -152,12 +160,17 @@ async def create(
             conversation_history=normalized_conversation,
             suppress_tripwire=suppress_tripwire,
         )
-        llm_call = self._client._resource_client.chat.completions.create(
-            messages=modified_messages,  # Use messages with any preflight modifications
-            model=model,
-            stream=stream,
+        # Only include safety_identifier for OpenAI clients (not Azure)
+        llm_kwargs = {
+            "messages": modified_messages,
+            "model": model,
+            "stream": stream,
             **kwargs,
-        )
+        }
+        if supports_safety_identifier(self._client._resource_client):
+            llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
+        llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs)

         input_results, llm_response = await asyncio.gather(input_check, llm_call)

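Under the same assumptions as the sketch above, the gating can be sanity-checked without any network calls; the keys and endpoints below are dummies:

from openai import AsyncAzureOpenAI, AsyncOpenAI

official = AsyncOpenAI(api_key="sk-dummy")  # defaults to https://api.openai.com/v1
local = AsyncOpenAI(api_key="unused", base_url="http://localhost:11434/v1")
azure = AsyncAzureOpenAI(
    api_key="dummy",
    api_version="2024-02-01",
    azure_endpoint="https://example.openai.azure.com",
)

assert supports_safety_identifier(official)
assert not supports_safety_identifier(local)
assert not supports_safety_identifier(azure)
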
69 changes: 48 additions & 21 deletions src/guardrails/resources/responses/responses.py
@@ -8,6 +8,7 @@
 from pydantic import BaseModel

 from ..._base_client import GuardrailsBaseClient
+from ...utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier


 class Responses:
@@ -63,13 +64,20 @@ def create(

         # Input guardrails and LLM call concurrently
         with ThreadPoolExecutor(max_workers=1) as executor:
+            # Only include safety_identifier for OpenAI clients (not Azure or local models)
+            llm_kwargs = {
+                "input": modified_input,
+                "model": model,
+                "stream": stream,
+                "tools": tools,
+                **kwargs,
+            }
+            if supports_safety_identifier(self._client._resource_client):
+                llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
             llm_future = executor.submit(
                 self._client._resource_client.responses.create,
-                input=modified_input,  # Use preflight-modified input
-                model=model,
-                stream=stream,
-                tools=tools,
-                **kwargs,
+                **llm_kwargs,
             )
             input_results = self._client._run_stage_guardrails(
                 "input",
@@ -123,12 +131,19 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM

         # Input guardrails and LLM call concurrently
         with ThreadPoolExecutor(max_workers=1) as executor:
+            # Only include safety_identifier for OpenAI clients (not Azure or local models)
+            llm_kwargs = {
+                "input": modified_input,
+                "model": model,
+                "text_format": text_format,
+                **kwargs,
+            }
+            if supports_safety_identifier(self._client._resource_client):
+                llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
             llm_future = executor.submit(
                 self._client._resource_client.responses.parse,
-                input=modified_input,  # Use modified input with preflight changes
-                model=model,
-                text_format=text_format,
-                **kwargs,
+                **llm_kwargs,
             )
             input_results = self._client._run_stage_guardrails(
                 "input",
@@ -218,13 +233,19 @@ async def create(
             conversation_history=normalized_conversation,
             suppress_tripwire=suppress_tripwire,
         )
-        llm_call = self._client._resource_client.responses.create(
-            input=modified_input,  # Use preflight-modified input
-            model=model,
-            stream=stream,
-            tools=tools,
+
+        # Only include safety_identifier for OpenAI clients (not Azure or local models)
+        llm_kwargs = {
+            "input": modified_input,
+            "model": model,
+            "stream": stream,
+            "tools": tools,
             **kwargs,
-        )
+        }
+        if supports_safety_identifier(self._client._resource_client):
+            llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
+        llm_call = self._client._resource_client.responses.create(**llm_kwargs)

         input_results, llm_response = await asyncio.gather(input_check, llm_call)

@@ -278,13 +299,19 @@ async def parse(
             conversation_history=normalized_conversation,
             suppress_tripwire=suppress_tripwire,
         )
-        llm_call = self._client._resource_client.responses.parse(
-            input=modified_input,  # Use modified input with preflight changes
-            model=model,
-            text_format=text_format,
-            stream=stream,
+
+        # Only include safety_identifier for OpenAI clients (not Azure or local models)
+        llm_kwargs = {
+            "input": modified_input,
+            "model": model,
+            "text_format": text_format,
+            "stream": stream,
             **kwargs,
-        )
+        }
+        if supports_safety_identifier(self._client._resource_client):
+            llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
+
+        llm_call = self._client._resource_client.responses.parse(**llm_kwargs)

         input_results, llm_response = await asyncio.gather(input_check, llm_call)

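The build-the-kwargs-then-gate pattern now appears at six call sites across chat.py and responses.py. A possible follow-up consolidation, not part of this PR and with a hypothetical helper name, would shrink each site to a single call:

from typing import Any

from guardrails.utils.safety_identifier import SAFETY_IDENTIFIER, supports_safety_identifier


def with_safety_identifier(client: Any, **llm_kwargs: Any) -> dict[str, Any]:
    """Return request kwargs, attaching safety_identifier when the client supports it."""
    if supports_safety_identifier(client):
        llm_kwargs["safety_identifier"] = SAFETY_IDENTIFIER
    return llm_kwargs

Each call site would then reduce to, for example, client.responses.create(**with_safety_identifier(client, input=modified_input, model=model, stream=stream, tools=tools)).
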
3 changes: 1 addition & 2 deletions src/guardrails/runtime.py
@@ -21,7 +21,6 @@
 from openai import AsyncOpenAI
 from pydantic import BaseModel, ConfigDict

-from ._openai_utils import prepare_openai_kwargs
 from .exceptions import ConfigError, GuardrailTripwireTriggered
 from .registry import GuardrailRegistry, default_spec_registry
 from .spec import GuardrailSpec
@@ -495,7 +494,7 @@ def _get_default_ctx():
     class DefaultCtx:
         guardrail_llm: AsyncOpenAI

-    return DefaultCtx(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    return DefaultCtx(guardrail_llm=AsyncOpenAI())


 async def check_plain_text(