25 changes: 0 additions & 25 deletions src/guardrails/_openai_utils.py

This file was deleted.

5 changes: 2 additions & 3 deletions src/guardrails/agents.py
@@ -18,7 +18,6 @@
 from pathlib import Path
 from typing import Any
 
-from ._openai_utils import prepare_openai_kwargs
 from .utils.conversation import merge_conversation_with_items, normalize_conversation
 
 logger = logging.getLogger(__name__)
@@ -167,7 +166,7 @@ def _create_default_tool_context() -> Any:
     class DefaultContext:
         guardrail_llm: AsyncOpenAI
 
-    return DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    return DefaultContext(guardrail_llm=AsyncOpenAI())
 
 
 def _create_conversation_context(
@@ -393,7 +392,7 @@ def _create_agents_guardrails_from_config(
     class DefaultContext:
         guardrail_llm: AsyncOpenAI
 
-    context = DefaultContext(guardrail_llm=AsyncOpenAI(**prepare_openai_kwargs({})))
+    context = DefaultContext(guardrail_llm=AsyncOpenAI())
 
     def _create_stage_guardrail(stage_name: str):
         async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
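Aside (not part of this PR): with prepare_openai_kwargs gone, a bare AsyncOpenAI() relies on openai-python's standard environment-based configuration. A minimal sketch, using a placeholder key:

import os

from openai import AsyncOpenAI

# With no explicit kwargs, the client reads OPENAI_API_KEY (and, if set,
# OPENAI_BASE_URL / OPENAI_ORG_ID) from the environment.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")  # placeholder for illustration
client = AsyncOpenAI()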
49 changes: 43 additions & 6 deletions src/guardrails/checks/text/llm_base.py
@@ -48,6 +48,37 @@ class MyLLMOutput(LLMOutput):
 from guardrails.types import CheckFn, GuardrailLLMContextProto, GuardrailResult
 from guardrails.utils.output import OutputSchema
 
+# OpenAI safety identifier for tracking guardrails library usage
+# Only supported by official OpenAI API (not Azure or local/alternative providers)
+_SAFETY_IDENTIFIER = "oai_guardrails"
+
+
+def _supports_safety_identifier(client: AsyncOpenAI | OpenAI | AsyncAzureOpenAI | AzureOpenAI) -> bool:
+    """Check if the client supports the safety_identifier parameter.
+
+    Only the official OpenAI API supports this parameter.
+    Azure OpenAI and local/alternative providers do not.
+
+    Args:
+        client: The OpenAI client instance.
+
+    Returns:
+        True if safety_identifier should be included, False otherwise.
+    """
+    # Azure clients don't support it
+    if isinstance(client, AsyncAzureOpenAI | AzureOpenAI):
+        return False
+
+    # Check if using a custom base_url (local or alternative provider)
+    base_url = getattr(client, "base_url", None)
+    if base_url is not None:
+        base_url_str = str(base_url)
+        # Only official OpenAI API endpoints support safety_identifier
+        return "api.openai.com" in base_url_str
+
+    # Default OpenAI client (no custom base_url) supports it
+    return True
+
 if TYPE_CHECKING:
     from openai import AsyncAzureOpenAI, AzureOpenAI  # type: ignore[unused-import]
 else:
@@ -247,12 +278,18 @@ async def _request_chat_completion(
     response_format: dict[str, Any],
 ) -> Any:
     """Invoke chat.completions.create on sync or async OpenAI clients."""
-    return await _invoke_openai_callable(
-        client.chat.completions.create,
-        messages=messages,
-        model=model,
-        response_format=response_format,
-    )
+    # Only include safety_identifier for official OpenAI API
+    kwargs: dict[str, Any] = {
+        "messages": messages,
+        "model": model,
+        "response_format": response_format,
+    }
+
+    # Only official OpenAI API supports safety_identifier (not Azure or local models)
+    if _supports_safety_identifier(client):
+        kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
+
+    return await _invoke_openai_callable(client.chat.completions.create, **kwargs)
 
 
 async def run_llm(
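A quick sketch of the helper's behavior (not in the diff; the import path is taken from the file header above, and it assumes openai-python's default base_url resolves to https://api.openai.com/v1):

from openai import AsyncAzureOpenAI, AsyncOpenAI

from guardrails.checks.text.llm_base import _supports_safety_identifier

# Default client: base_url is https://api.openai.com/v1, so the substring
# check passes and the identifier is sent.
assert _supports_safety_identifier(AsyncOpenAI(api_key="sk-placeholder"))

# Custom base_url (e.g. a local inference server): identifier is withheld.
local = AsyncOpenAI(api_key="sk-placeholder", base_url="http://localhost:11434/v1")
assert not _supports_safety_identifier(local)

# Azure clients fail the isinstance check outright.
azure = AsyncAzureOpenAI(
    api_key="placeholder",
    api_version="2024-06-01",  # illustrative version string
    azure_endpoint="https://example.openai.azure.com",
)
assert not _supports_safety_identifier(azure)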
4 changes: 1 addition & 3 deletions src/guardrails/checks/text/moderation.py
@@ -39,8 +39,6 @@
 from guardrails.spec import GuardrailSpecMetadata
 from guardrails.types import GuardrailResult
 
-from ..._openai_utils import prepare_openai_kwargs
-
 logger = logging.getLogger(__name__)
 
 __all__ = ["moderation", "Category", "ModerationCfg"]
@@ -129,7 +127,7 @@ def _get_moderation_client() -> AsyncOpenAI:
     Returns:
         AsyncOpenAI: Cached OpenAI API client for moderation checks.
     """
-    return AsyncOpenAI(**prepare_openai_kwargs({}))
+    return AsyncOpenAI()
 
 
 async def moderation(
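The hunk shows only the return statement; the docstring's "Cached OpenAI API client" implies the getter is memoized outside this hunk. A minimal sketch of that pattern, assuming functools.lru_cache (the actual decorator is not visible here):

from functools import lru_cache

from openai import AsyncOpenAI


@lru_cache(maxsize=1)
def _get_moderation_client() -> AsyncOpenAI:
    """Construct the client once and reuse it across moderation calls."""
    return AsyncOpenAI()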
9 changes: 2 additions & 7 deletions src/guardrails/client.py
@@ -26,7 +26,6 @@
     GuardrailsResponse,
     OpenAIResponseType,
 )
-from ._openai_utils import prepare_openai_kwargs
 from ._streaming import StreamingMixin
 from .exceptions import GuardrailTripwireTriggered
 from .runtime import run_guardrails
@@ -167,7 +166,6 @@ def __init__(
                 by this parameter.
             **openai_kwargs: Additional arguments passed to AsyncOpenAI constructor.
         """
-        openai_kwargs = prepare_openai_kwargs(openai_kwargs)
         # Initialize OpenAI client first
         super().__init__(**openai_kwargs)
 
@@ -205,7 +203,7 @@ class DefaultContext:
         default_headers = getattr(self, "default_headers", None)
         if default_headers is not None:
             guardrail_kwargs["default_headers"] = default_headers
-        guardrail_client = AsyncOpenAI(**prepare_openai_kwargs(guardrail_kwargs))
+        guardrail_client = AsyncOpenAI(**guardrail_kwargs)
 
         return DefaultContext(guardrail_llm=guardrail_client)
 
@@ -335,7 +333,6 @@ def __init__(
                 by this parameter.
             **openai_kwargs: Additional arguments passed to OpenAI constructor.
         """
-        openai_kwargs = prepare_openai_kwargs(openai_kwargs)
         # Initialize OpenAI client first
         super().__init__(**openai_kwargs)
 
@@ -373,7 +370,7 @@ class DefaultContext:
         default_headers = getattr(self, "default_headers", None)
         if default_headers is not None:
             guardrail_kwargs["default_headers"] = default_headers
-        guardrail_client = OpenAI(**prepare_openai_kwargs(guardrail_kwargs))
+        guardrail_client = OpenAI(**guardrail_kwargs)
 
         return DefaultContext(guardrail_llm=guardrail_client)
 
@@ -516,7 +513,6 @@ def __init__(
                 by this parameter.
             **azure_kwargs: Additional arguments passed to AsyncAzureOpenAI constructor.
         """
-        azure_kwargs = prepare_openai_kwargs(azure_kwargs)
         # Initialize Azure client first
         super().__init__(**azure_kwargs)
 
@@ -671,7 +667,6 @@ def __init__(
                 by this parameter.
             **azure_kwargs: Additional arguments passed to AzureOpenAI constructor.
         """
-        azure_kwargs = prepare_openai_kwargs(azure_kwargs)
         super().__init__(**azure_kwargs)
 
         # Store the error handling preference
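All four constructors above now follow the same shape: copy selected settings from the outer client, then build a plain guardrail client. A condensed sketch (the helper name is hypothetical, and the seeding of guardrail_kwargs with api_key etc. happens outside these hunks):

from typing import Any

from openai import AsyncOpenAI


def _build_guardrail_client(outer: AsyncOpenAI) -> AsyncOpenAI:
    guardrail_kwargs: dict[str, Any] = {}  # upstream code also copies api_key etc. (not shown)
    default_headers = getattr(outer, "default_headers", None)
    if default_headers is not None:
        guardrail_kwargs["default_headers"] = default_headers
    return AsyncOpenAI(**guardrail_kwargs)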
5 changes: 2 additions & 3 deletions src/guardrails/evals/guardrail_evals.py
@@ -23,7 +23,6 @@
 
 
 from guardrails import instantiate_guardrails, load_pipeline_bundles
-from guardrails._openai_utils import prepare_openai_kwargs
 from guardrails.evals.core import (
     AsyncRunEngine,
     BenchmarkMetricsCalculator,
@@ -281,7 +280,7 @@ def _create_context(self) -> Context:
             if self.api_key:
                 azure_kwargs["api_key"] = self.api_key
 
-            guardrail_llm = AsyncAzureOpenAI(**prepare_openai_kwargs(azure_kwargs))
+            guardrail_llm = AsyncAzureOpenAI(**azure_kwargs)
             logger.info("Created Azure OpenAI client for endpoint: %s", self.azure_endpoint)
         # OpenAI or OpenAI-compatible API
         else:
@@ -292,7 +291,7 @@
                 openai_kwargs["base_url"] = self.base_url
                 logger.info("Created OpenAI-compatible client for base_url: %s", self.base_url)
 
-            guardrail_llm = AsyncOpenAI(**prepare_openai_kwargs(openai_kwargs))
+            guardrail_llm = AsyncOpenAI(**openai_kwargs)
 
         return Context(guardrail_llm=guardrail_llm)
 
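_create_context's branch reduces to: Azure wins when an endpoint is configured, otherwise an OpenAI-compatible client is built. A condensed sketch (the function name and hard-coded api_version are illustrative, not from the diff):

from typing import Any

from openai import AsyncAzureOpenAI, AsyncOpenAI


def create_guardrail_llm(
    azure_endpoint: str | None,
    api_key: str | None,
    base_url: str | None,
) -> AsyncAzureOpenAI | AsyncOpenAI:
    if azure_endpoint:
        azure_kwargs: dict[str, Any] = {
            "azure_endpoint": azure_endpoint,
            "api_version": "2024-06-01",  # illustrative; the eval's real setting is not shown
        }
        if api_key:
            azure_kwargs["api_key"] = api_key
        return AsyncAzureOpenAI(**azure_kwargs)
    openai_kwargs: dict[str, Any] = {}
    if api_key:
        openai_kwargs["api_key"] = api_key
    if base_url:
        openai_kwargs["base_url"] = base_url
    return AsyncOpenAI(**openai_kwargs)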
62 changes: 53 additions & 9 deletions src/guardrails/resources/chat/chat.py
@@ -7,6 +7,38 @@
 
 from ..._base_client import GuardrailsBaseClient
 
+# OpenAI safety identifier for tracking guardrails library usage
+# Only supported by official OpenAI API (not Azure or local/alternative providers)
+_SAFETY_IDENTIFIER = "oai_guardrails"
+
+
+def _supports_safety_identifier(client: Any) -> bool:
+    """Check if the client supports the safety_identifier parameter.
+
+    Only the official OpenAI API supports this parameter.
+    Azure OpenAI and local/alternative providers do not.
+
+    Args:
+        client: The OpenAI client instance.
+
+    Returns:
+        True if safety_identifier should be included, False otherwise.
+    """
+    # Azure clients don't support it
+    client_type = type(client).__name__
+    if "Azure" in client_type:
+        return False
+
+    # Check if using a custom base_url (local or alternative provider)
+    base_url = getattr(client, "base_url", None)
+    if base_url is not None:
+        base_url_str = str(base_url)
+        # Only official OpenAI API endpoints support safety_identifier
+        return "api.openai.com" in base_url_str
+
+    # Default OpenAI client (no custom base_url) supports it
+    return True
+
 
 class Chat:
     """Chat completions with guardrails (sync)."""
@@ -82,12 +114,19 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals
 
         # Run input guardrails and LLM call concurrently using a thread for the LLM
         with ThreadPoolExecutor(max_workers=1) as executor:
+            # Only include safety_identifier for OpenAI clients (not Azure)
+            llm_kwargs = {
+                "messages": modified_messages,
+                "model": model,
+                "stream": stream,
+                **kwargs,
+            }
+            if _supports_safety_identifier(self._client._resource_client):
+                llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
+
             llm_future = executor.submit(
                 self._client._resource_client.chat.completions.create,
-                messages=modified_messages,  # Use messages with any preflight modifications
-                model=model,
-                stream=stream,
-                **kwargs,
+                **llm_kwargs,
             )
             input_results = self._client._run_stage_guardrails(
                 "input",
@@ -152,12 +191,17 @@ async def create(
             conversation_history=normalized_conversation,
             suppress_tripwire=suppress_tripwire,
         )
-        llm_call = self._client._resource_client.chat.completions.create(
-            messages=modified_messages,  # Use messages with any preflight modifications
-            model=model,
-            stream=stream,
+        # Only include safety_identifier for OpenAI clients (not Azure)
+        llm_kwargs = {
+            "messages": modified_messages,
+            "model": model,
+            "stream": stream,
             **kwargs,
-        )
+        }
+        if _supports_safety_identifier(self._client._resource_client):
+            llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
+
+        llm_call = self._client._resource_client.chat.completions.create(**llm_kwargs)
 
         input_results, llm_response = await asyncio.gather(input_check, llm_call)
 
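Both create() methods now build their kwargs the same way; a standalone restatement of that shared step (build_llm_kwargs is a hypothetical name, with the private helpers imported from the module this diff touches):

from typing import Any

from guardrails.resources.chat.chat import _SAFETY_IDENTIFIER, _supports_safety_identifier


def build_llm_kwargs(
    client: Any,
    messages: list[dict[str, str]],
    model: str,
    stream: bool,
    **kwargs: Any,
) -> dict[str, Any]:
    llm_kwargs: dict[str, Any] = {
        "messages": messages,
        "model": model,
        "stream": stream,
        **kwargs,
    }
    # The identifier is only attached when the official OpenAI API is the target.
    if _supports_safety_identifier(client):
        llm_kwargs["safety_identifier"] = _SAFETY_IDENTIFIER
    return llm_kwargs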