From e17edb209e256e77913cb13e0d0c3c1e6a0c6679 Mon Sep 17 00:00:00 2001 From: Pouyanpi <13303554+Pouyanpi@users.noreply.github.com> Date: Tue, 9 Sep 2025 12:29:08 +0200 Subject: [PATCH] feat(tool-rails): add support for tool output rails and validation Introduce tool output/input rails configuration and Colang flows for tool call validation and parameter security checks. Add support for BotToolCall event emission in passthrough mode, enabling tool call guardrails before execution. --- nemoguardrails/actions/llm/generation.py | 16 +- nemoguardrails/actions/llm/utils.py | 15 + nemoguardrails/rails/llm/config.py | 42 ++ nemoguardrails/rails/llm/llm_flows.co | 50 ++ nemoguardrails/rails/llm/llmrails.py | 4 +- nemoguardrails/rails/llm/options.py | 12 + tests/runnable_rails/test_tool_calling.py | 268 +++++++++- tests/test_bot_tool_call_events.py | 274 ++++++++++ tests/test_output_rails_tool_calls.py | 304 +++++++++++ ...st_tool_calling_passthrough_integration.py | 28 +- tests/test_tool_calling_passthrough_only.py | 216 ++++++++ tests/test_tool_calls_event_extraction.py | 506 ++++++++++++++++++ tests/test_tool_output_rails.py | 243 +++++++++ 13 files changed, 1943 insertions(+), 35 deletions(-) create mode 100644 tests/test_bot_tool_call_events.py create mode 100644 tests/test_output_rails_tool_calls.py create mode 100644 tests/test_tool_calling_passthrough_only.py create mode 100644 tests/test_tool_calls_event_extraction.py create mode 100644 tests/test_tool_output_rails.py diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 2a57e1c26..f54761aea 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -582,7 +582,21 @@ async def generate_user_intent( if streaming_handler: await streaming_handler.push_chunk(text) - output_events.append(new_event_dict("BotMessage", text=text)) + if self.config.passthrough: + from nemoguardrails.actions.llm.utils import ( + get_and_clear_tool_calls_contextvar, + ) + + tool_calls = get_and_clear_tool_calls_contextvar() + + if tool_calls: + output_events.append( + new_event_dict("BotToolCall", tool_calls=tool_calls) + ) + else: + output_events.append(new_event_dict("BotMessage", text=text)) + else: + output_events.append(new_event_dict("BotMessage", text=text)) return ActionResult(events=output_events) diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py index 3b1dbd062..d357037da 100644 --- a/nemoguardrails/actions/llm/utils.py +++ b/nemoguardrails/actions/llm/utils.py @@ -673,6 +673,21 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]: return None +def extract_tool_calls_from_events(events: list) -> Optional[list]: + """Extract tool_calls from BotToolCall events. + + Args: + events: List of events to search through + + Returns: + tool_calls if found in BotToolCall event, None otherwise + """ + for event in events: + if event.get("type") == "BotToolCall": + return event.get("tool_calls") + return None + + def get_and_clear_response_metadata_contextvar() -> Optional[dict]: """Get the current response metadata and clear it from the context. diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index bc12569a1..76b9f92e1 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -527,6 +527,40 @@ class ActionRails(BaseModel): ) +class ToolOutputRails(BaseModel): + """Configuration of tool output rails. 
+ + Tool output rails are applied to tool calls before they are executed. + They can validate tool names, parameters, and context to ensure safe tool usage. + """ + + flows: List[str] = Field( + default_factory=list, + description="The names of all the flows that implement tool output rails.", + ) + parallel: Optional[bool] = Field( + default=False, + description="If True, the tool output rails are executed in parallel.", + ) + + +class ToolInputRails(BaseModel): + """Configuration of tool input rails. + + Tool input rails are applied to tool results before they are processed. + They can validate, filter, or transform tool outputs for security and safety. + """ + + flows: List[str] = Field( + default_factory=list, + description="The names of all the flows that implement tool input rails.", + ) + parallel: Optional[bool] = Field( + default=False, + description="If True, the tool input rails are executed in parallel.", + ) + + class SingleCallConfig(BaseModel): """Configuration for the single LLM call option for topical rails.""" @@ -912,6 +946,14 @@ class Rails(BaseModel): actions: ActionRails = Field( default_factory=ActionRails, description="Configuration of action rails." ) + tool_output: ToolOutputRails = Field( + default_factory=ToolOutputRails, + description="Configuration of tool output rails.", + ) + tool_input: ToolInputRails = Field( + default_factory=ToolInputRails, + description="Configuration of tool input rails.", + ) def merge_two_dicts(dict_1: dict, dict_2: dict, ignore_keys: Set[str]) -> None: diff --git a/nemoguardrails/rails/llm/llm_flows.co b/nemoguardrails/rails/llm/llm_flows.co index 63edb266b..9c4d87372 100644 --- a/nemoguardrails/rails/llm/llm_flows.co +++ b/nemoguardrails/rails/llm/llm_flows.co @@ -102,6 +102,34 @@ define parallel extension flow generate bot message execute generate_bot_message +define parallel extension flow process bot tool call + """Processes tool calls from the bot.""" + priority 100 + + event BotToolCall + + $tool_calls = $event.tool_calls + + # Run tool-specific output rails if configured (Phase 2) + if $config.rails.tool_output.flows + # If we have generation options, we make sure the tool output rails are enabled. + if $generation_options is None or $generation_options.rails.tool_output: + # Create a marker event. + create event StartToolOutputRails + event StartToolOutputRails + + # Run all the tool output rails + # This can potentially alter or block the tool calls + do run tool output rails + + # Create a marker event. + create event ToolOutputRailsFinished + event ToolOutputRailsFinished + + # Create the action event for tool execution + create event StartToolCallBotAction(tool_calls=$tool_calls) + + define parallel extension flow process bot message """Runs the output rails on a bot message.""" priority 100 @@ -164,3 +192,25 @@ define subflow run retrieval rails while $i < len($retrieval_flows) do $retrieval_flows[$i] $i = $i + 1 + + +define subflow run tool output rails + """Runs all the tool output rails in a sequential order.""" + $tool_output_flows = $config.rails.tool_output.flows + + $i = 0 + while $i < len($tool_output_flows) + # We set the current rail as being triggered. + $triggered_tool_output_rail = $tool_output_flows[$i] + + create event StartToolOutputRail(flow_id=$triggered_tool_output_rail) + event StartToolOutputRail + + do $tool_output_flows[$i] + $i = $i + 1 + + create event ToolOutputRailFinished(flow_id=$triggered_tool_output_rail) + event ToolOutputRailFinished + + # If all went smooth, we remove it. 
+ $triggered_tool_output_rail = None diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 6a229e829..4d205ef9b 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -32,9 +32,9 @@ from nemoguardrails.actions.llm.generation import LLMGenerationActions from nemoguardrails.actions.llm.utils import ( + extract_tool_calls_from_events, get_and_clear_reasoning_trace_contextvar, get_and_clear_response_metadata_contextvar, - get_and_clear_tool_calls_contextvar, get_colang_history, ) from nemoguardrails.actions.output_mapping import is_output_blocked @@ -1086,7 +1086,7 @@ async def generate_async( options.log.llm_calls = True options.log.internal_events = True - tool_calls = get_and_clear_tool_calls_contextvar() + tool_calls = extract_tool_calls_from_events(new_events) llm_metadata = get_and_clear_response_metadata_contextvar() # If we have generation options, we prepare a GenerationResponse instance. diff --git a/nemoguardrails/rails/llm/options.py b/nemoguardrails/rails/llm/options.py index 3ccb054e9..dd9f87099 100644 --- a/nemoguardrails/rails/llm/options.py +++ b/nemoguardrails/rails/llm/options.py @@ -127,6 +127,16 @@ class GenerationRailsOptions(BaseModel): default=True, description="Whether the dialog rails are enabled or not.", ) + tool_output: Union[bool, List[str]] = Field( + default=True, + description="Whether the tool output rails are enabled or not. " + "If a list of names is specified, then only the specified tool output rails will be applied.", + ) + tool_input: Union[bool, List[str]] = Field( + default=True, + description="Whether the tool input rails are enabled or not. " + "If a list of names is specified, then only the specified tool input rails will be applied.", + ) class GenerationOptions(BaseModel): @@ -177,6 +187,8 @@ def check_fields(cls, values): "dialog": False, "retrieval": False, "output": False, + "tool_output": False, + "tool_input": False, } for rail_type in values["rails"]: _rails[rail_type] = True diff --git a/tests/runnable_rails/test_tool_calling.py b/tests/runnable_rails/test_tool_calling.py index ebf658795..fb42f357c 100644 --- a/tests/runnable_rails/test_tool_calling.py +++ b/tests/runnable_rails/test_tool_calling.py @@ -101,7 +101,7 @@ async def ainvoke(self, messages, **kwargs): result = chain.invoke({"input": "What's the weather?"}) assert isinstance(result, AIMessage) - assert result.content == "I'll check the weather for you." + assert result.content == "" assert result.tool_calls is not None assert len(result.tool_calls) == 1 assert result.tool_calls[0]["name"] == "get_weather" @@ -170,27 +170,257 @@ async def ainvoke(self, messages, **kwargs): assert result["output"].tool_calls[0]["name"] == "test_tool" -@pytest.mark.skipif( - not has_nvidia_ai_endpoints(), - reason="langchain-nvidia-ai-endpoints package not installed", -) -def test_runnable_binding_treated_as_llm(): - """Test that RunnableBinding with LLM tools is treated as an LLM, not passthrough_runnable.""" - from langchain_core.tools import tool - from langchain_nvidia_ai_endpoints import ChatNVIDIA +def test_tool_calls_with_output_rails(): + """Test that tool calls bypass output rails and don't get blocked.""" - @tool - def get_weather(city: str) -> str: - """Get weather for a given city.""" - return f"It's sunny in {city}!" 
+ class MockLLMWithForcedTools: + def invoke(self, messages, **kwargs): + # simulate enforced tool choice which returns empty content with tool_calls + return AIMessage( + content="", + tool_calls=[ + { + "name": "test_tool", + "args": {"param": "value"}, + "id": "call_test123", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + config = RailsConfig.from_content( + """ + define flow block_empty_output + if $bot_message == "" + bot refuse to respond + stop + """, + """ + rails: + output: + flows: + - block_empty_output + """, + ) + + rails = RunnableRails(config, llm=MockLLMWithForcedTools()) + result = rails.invoke(HumanMessage(content="Test tool call")) + + assert isinstance(result, AIMessage) + assert result.content != "I'm sorry, I can't respond to that." + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "test_tool" + + +def test_empty_content_with_tool_calls_not_blocked(): + """Test that empty content with tool_calls doesn't trigger refuse to respond.""" + + class MockLLMWithEmptyContentAndTools: + def invoke(self, messages, **kwargs): + return AIMessage( + content="", + tool_calls=[ + { + "name": "gather_info", + "args": {"name": "John", "dob": "1990-01-01"}, + "id": "call_gather123", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + config = RailsConfig.from_content( + colang_content="", + yaml_content=""" + models: [] + rails: + output: + flows: + - self check output + + prompts: + - task: self_check_output + content: | + Instructions: {instructions} + Output: {output} + + Check if the output is appropriate and safe. + """, + ) + + rails = RunnableRails(config, llm=MockLLMWithEmptyContentAndTools()) + result = rails.invoke(HumanMessage(content="Test message")) + + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + assert "I'm sorry, I can't respond to that." 
not in result.content + + +def test_bot_tool_call_event_creation(): + """Test that BotToolCall events are created instead of BotMessage when tool_calls exist.""" + + class MockLLMReturningToolCall: + def invoke(self, messages, **kwargs): + return AIMessage( + content="", + tool_calls=[ + { + "name": "weather_tool", + "args": {"location": "NYC"}, + "id": "call_weather456", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) config = RailsConfig.from_content(config={"models": []}) - guardrails = RunnableRails(config=config, passthrough=True) + rails = RunnableRails(config, llm=MockLLMReturningToolCall()) - llm = ChatNVIDIA(model="meta/llama-3.3-70b-instruct") - llm_with_tools = llm.bind_tools([get_weather]) + result = rails.invoke(HumanMessage(content="Get weather")) - piped = guardrails | llm_with_tools + assert isinstance(result, AIMessage) + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "weather_tool" + assert result.tool_calls[0]["args"]["location"] == "NYC" - assert piped.llm is llm_with_tools - assert piped.passthrough_runnable is None + +def test_tool_calls_enforced_choice(): + """Test enforced tool_choice scenario that was originally failing.""" + + class MockLLMWithEnforcedTool: + def invoke(self, messages, **kwargs): + # simulates bind_tools with tool_choice - always calls specific tool + return AIMessage( + content="", + tool_calls=[ + { + "name": "print_gathered_patient_info", + "args": { + "patient_name": "John Doe", + "patient_dob": "01/01/1990", + }, + "id": "call_patient789", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + config = RailsConfig.from_content( + colang_content="", + yaml_content=""" + models: [] + rails: + output: + flows: + - self check output + + prompts: + - task: self_check_output + content: | + Instructions: {instructions} + Output: {output} + + Check if the output is appropriate and safe. + """, + ) + + rails = RunnableRails(config, llm=MockLLMWithEnforcedTool()) + result = rails.invoke(HumanMessage(content="Hi!")) + + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "print_gathered_patient_info" + assert result.content == "" + assert "I'm sorry, I can't respond to that." not in result.content + + +def test_complex_chain_with_tool_calls(): + """Test tool calls work in complex LangChain scenarios.""" + + class MockPatientIntakeLLM: + def invoke(self, messages, **kwargs): + return AIMessage( + content="", + tool_calls=[ + { + "name": "print_gathered_patient_info", + "args": { + "patient_name": "John Doe", + "patient_dob": "01/01/1990", + }, + "id": "call_intake", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + system_prompt = """ + You are a specialized assistant for handling patient intake. + After gathering all information, use the print_gathered_patient_info tool. + """ + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + ("placeholder", "{messages}"), + ] + ) + + config = RailsConfig.from_content( + colang_content="", + yaml_content=""" + models: [] + rails: + output: + flows: + - self check output + + prompts: + - task: self_check_output + content: | + Instructions: {instructions} + Output: {output} + + Check if the output is appropriate and safe. 
+ """, + ) + + guardrails = RunnableRails( + config=config, llm=MockPatientIntakeLLM(), passthrough=True + ) + + chain = prompt | guardrails + + result = chain.invoke( + { + "messages": [ + ("user", "Hi!"), + ("assistant", "Welcome! What's your name?"), + ("user", "My name is John Doe."), + ("assistant", "What's your date of birth?"), + ("user", "My date of birth is 01/01/1990."), + ] + } + ) + + assert isinstance(result, AIMessage) + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "print_gathered_patient_info" + assert result.tool_calls[0]["args"]["patient_name"] == "John Doe" + assert result.content == "" + assert "I'm sorry, I can't respond to that." not in result.content diff --git a/tests/test_bot_tool_call_events.py b/tests/test_bot_tool_call_events.py new file mode 100644 index 000000000..400432e55 --- /dev/null +++ b/tests/test_bot_tool_call_events.py @@ -0,0 +1,274 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for BotToolCall event handling in NeMo Guardrails.""" + +from unittest.mock import patch + +import pytest + +from nemoguardrails import RailsConfig +from tests.utils import TestChat + + +@pytest.mark.asyncio +async def test_bot_tool_call_event_creation(): + """Test that BotToolCall events are created when tool_calls are present.""" + + test_tool_calls = [ + { + "name": "test_function", + "args": {"param1": "value1"}, + "id": "call_12345", + "type": "tool_call", + } + ] + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + chat = TestChat(config, llm_completions=[""]) + + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Test"}] + ) + + assert result["tool_calls"] is not None + assert len(result["tool_calls"]) == 1 + assert result["tool_calls"][0]["name"] == "test_function" + + +@pytest.mark.asyncio +async def test_bot_message_vs_bot_tool_call_event(): + """Test that regular text creates BotMessage, tool calls create BotToolCall.""" + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = None + + chat_text = TestChat(config, llm_completions=["Regular text response"]) + result_text = await chat_text.app.generate_async( + messages=[{"role": "user", "content": "Hello"}] + ) + + assert result_text["content"] == "Regular text response" + assert ( + result_text.get("tool_calls") is None or result_text.get("tool_calls") == [] + ) + + test_tool_calls = [ + { + "name": "toggle_tool", + "args": {}, + "id": "call_toggle", + "type": "tool_call", + } + ] + + with patch( + 
"nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + chat_tools = TestChat(config, llm_completions=[""]) + result_tools = await chat_tools.app.generate_async( + messages=[{"role": "user", "content": "Use tool"}] + ) + + assert result_tools["tool_calls"] is not None + assert result_tools["tool_calls"][0]["name"] == "toggle_tool" + + +@pytest.mark.asyncio +async def test_tool_calls_bypass_output_rails(): + """Test that tool calls bypass output rails in passthrough mode.""" + + test_tool_calls = [ + { + "name": "critical_tool", + "args": {"action": "execute"}, + "id": "call_critical", + "type": "tool_call", + } + ] + + config = RailsConfig.from_content( + """ + define flow block_empty_content + if $bot_message == "" + bot refuse to respond + stop + """, + """ + models: [] + passthrough: true + rails: + output: + flows: + - block_empty_content + """, + ) + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + chat = TestChat(config, llm_completions=[""]) + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Execute"}] + ) + + assert result["tool_calls"] is not None + assert result["tool_calls"][0]["name"] == "critical_tool" + + +@pytest.mark.asyncio +async def test_mixed_content_and_tool_calls(): + """Test responses that have both content and tool calls.""" + + test_tool_calls = [ + { + "name": "transmit_data", + "args": {"info": "user_data"}, + "id": "call_transmit", + "type": "tool_call", + } + ] + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + chat = TestChat( + config, + llm_completions=["I found the information and will now transmit it."], + ) + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Process data"}] + ) + + assert result["tool_calls"] is not None + assert result["tool_calls"][0]["name"] == "transmit_data" + + +@pytest.mark.asyncio +async def test_multiple_tool_calls(): + """Test handling of multiple tool calls in a single response.""" + + test_tool_calls = [ + { + "name": "tool_one", + "args": {"param": "first"}, + "id": "call_one", + "type": "tool_call", + }, + { + "name": "tool_two", + "args": {"param": "second"}, + "id": "call_two", + "type": "tool_call", + }, + ] + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + chat = TestChat(config, llm_completions=[""]) + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Execute multiple tools"}] + ) + + assert result["tool_calls"] is not None + assert len(result["tool_calls"]) == 2 + assert result["tool_calls"][0]["name"] == "tool_one" + assert result["tool_calls"][1]["name"] == "tool_two" + + +@pytest.mark.asyncio +async def test_regular_text_still_goes_through_output_rails(): + """Test that regular text responses still go through output rails.""" + + config = RailsConfig.from_content( + """ + define flow add_prefix + $bot_message = "PREFIX: " + $bot_message + """, + """ + rails: + output: + flows: + - add_prefix + """, + ) + + with patch( + 
"nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = None + + chat = TestChat(config, llm_completions=["This is a regular response"]) + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Say something"}] + ) + + assert "PREFIX: This is a regular response" in result["content"] + assert result.get("tool_calls") is None or result.get("tool_calls") == [] + + +@pytest.mark.asyncio +async def test_empty_text_without_tool_calls_still_blocked(): + """Test that empty text without tool_calls is still blocked by output rails.""" + + config = RailsConfig.from_content( + """ + define flow block_empty + if $bot_message == "" + bot refuse to respond + stop + """, + """ + rails: + output: + flows: + - block_empty + """, + ) + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = None + + chat = TestChat(config, llm_completions=[""]) + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Say something"}] + ) + + assert "I'm sorry, I can't respond to that." in result["content"] + assert result.get("tool_calls") is None or result.get("tool_calls") == [] diff --git a/tests/test_output_rails_tool_calls.py b/tests/test_output_rails_tool_calls.py new file mode 100644 index 000000000..36288b702 --- /dev/null +++ b/tests/test_output_rails_tool_calls.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Integration tests for tool calls with output rails.""" + +import pytest +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.prompts import ChatPromptTemplate + +from nemoguardrails import RailsConfig +from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails + + +def test_output_rails_skip_for_tool_calls(): + """Test that output rails are skipped when tool calls are present.""" + + class MockLLMWithToolResponse: + def invoke(self, messages, **kwargs): + return AIMessage( + content="", + tool_calls=[ + { + "name": "process_data", + "args": {"data": "test"}, + "id": "call_process", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + # Config with aggressive output rails that would block empty content + config = RailsConfig.from_content( + """ + define flow strict_output_check + if $bot_message == "" + bot refuse to respond + stop + + define flow add_prefix + $bot_message = "PREFIX: " + $bot_message + """, + """ + rails: + output: + flows: + - strict_output_check + - add_prefix + """, + ) + + rails = RunnableRails(config, llm=MockLLMWithToolResponse()) + result = rails.invoke(HumanMessage(content="Process this")) + + # Tool calls should bypass output rails entirely + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "process_data" + assert result.content == "" # Should stay empty, not modified by rails + assert "I'm sorry, I can't respond to that." not in result.content + assert "PREFIX:" not in result.content # Rails should not have run + + +def test_text_responses_still_use_output_rails(): + """Test that regular text responses still go through output rails.""" + + class MockLLMTextResponse: + def invoke(self, messages, **kwargs): + return AIMessage(content="Hello there") + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + # Same config as above test + config = RailsConfig.from_content( + """ + define flow add_prefix + $bot_message = "PREFIX: " + $bot_message + """, + """ + rails: + output: + flows: + - add_prefix + """, + ) + + rails = RunnableRails(config, llm=MockLLMTextResponse()) + result = rails.invoke(HumanMessage(content="Say hello")) + + assert "PREFIX: Hello there" in result.content + assert result.tool_calls is None or result.tool_calls == [] + + +def test_complex_chain_with_tool_calls(): + """Test tool calls work in complex LangChain scenarios.""" + + class MockPatientIntakeLLM: + def invoke(self, messages, **kwargs): + return AIMessage( + content="", + tool_calls=[ + { + "name": "print_gathered_patient_info", + "args": { + "patient_name": "John Doe", + "patient_dob": "01/01/1990", + }, + "id": "call_intake", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + system_prompt = """ + You are a specialized assistant for handling patient intake. + After gathering all information, use the print_gathered_patient_info tool. + """ + + prompt = ChatPromptTemplate.from_messages( + [ + ("system", system_prompt), + ("placeholder", "{messages}"), + ] + ) + + config = RailsConfig.from_content( + colang_content="", + yaml_content=""" + models: [] + rails: + output: + flows: + - self check output + + prompts: + - task: self_check_output + content: | + Instructions: {instructions} + Output: {output} + + Check if the output is appropriate and safe. 
+ """, + ) + + guardrails = RunnableRails( + config=config, llm=MockPatientIntakeLLM(), passthrough=True + ) + + chain = prompt | guardrails + + result = chain.invoke( + { + "messages": [ + ("user", "Hi!"), + ("assistant", "Welcome! What's your name?"), + ("user", "My name is John Doe."), + ("assistant", "What's your date of birth?"), + ("user", "My date of birth is 01/01/1990."), + ] + } + ) + + assert isinstance(result, AIMessage) + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "print_gathered_patient_info" + assert result.tool_calls[0]["args"]["patient_name"] == "John Doe" + assert result.content == "" + assert "I'm sorry, I can't respond to that." not in result.content + + +def test_self_check_output_rail_bypassed(): + """Test that self_check_output rail is bypassed for tool calls.""" + + class MockLLMToolCallsWithSelfCheck: + def invoke(self, messages, **kwargs): + return AIMessage( + content="", + tool_calls=[ + { + "name": "sensitive_operation", + "args": {"action": "process"}, + "id": "call_sensitive", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + config = RailsConfig.from_content( + colang_content="", + yaml_content=""" + models: [] + rails: + output: + flows: + - self check output + + prompts: + - task: self_check_output + content: | + Instructions: {instructions} + Output: {output} + + Check if the output is appropriate and safe. + """, + ) + + rails = RunnableRails(config, llm=MockLLMToolCallsWithSelfCheck()) + result = rails.invoke(HumanMessage(content="Perform sensitive operation")) + + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "sensitive_operation" + assert "I'm sorry, I can't respond to that." not in result.content + + +def test_backward_compatibility_text_blocking(): + """Test that text-based blocking still works for non-tool responses.""" + + class MockLLMProblematicText: + def invoke(self, messages, **kwargs): + return AIMessage(content="This response should be blocked by output rails") + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + config = RailsConfig.from_content( + """ + define flow block_problematic + if "should be blocked" in $bot_message + bot refuse to respond + stop + """, + """ + rails: + output: + flows: + - block_problematic + """, + ) + + rails = RunnableRails(config, llm=MockLLMProblematicText()) + result = rails.invoke(HumanMessage(content="Say something bad")) + + assert "I'm sorry, I can't respond to that." 
in result.content + assert result.tool_calls is None or result.tool_calls == [] + + +def test_mixed_tool_calls_and_content(): + """Test responses that have both content and tool calls.""" + + class MockLLMWithBoth: + def invoke(self, messages, **kwargs): + return AIMessage( + content="I'll gather the information for you.", + tool_calls=[ + { + "name": "gather_info", + "args": {"user_id": "123"}, + "id": "call_gather", + "type": "tool_call", + } + ], + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + config = RailsConfig.from_content( + """ + define flow add_timestamp + $bot_message = $bot_message + " [" + $current_time + "]" + """, + """ + rails: + output: + flows: + - add_timestamp + """, + ) + + rails = RunnableRails(config, llm=MockLLMWithBoth()) + result = rails.invoke(HumanMessage(content="Gather my info")) + + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "gather_info" diff --git a/tests/test_tool_calling_passthrough_integration.py b/tests/test_tool_calling_passthrough_integration.py index ca1689b97..886213553 100644 --- a/tests/test_tool_calling_passthrough_integration.py +++ b/tests/test_tool_calling_passthrough_integration.py @@ -51,13 +51,13 @@ async def test_tool_calls_work_in_passthrough_mode_with_options(self): ] with patch( - "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar" + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" ) as mock_get_clear: mock_get_clear.return_value = test_tool_calls chat = TestChat( self.passthrough_config, - llm_completions=["I'll help you with the weather and calculation."], + llm_completions=[""], ) result = await chat.app.generate_async( @@ -75,7 +75,7 @@ async def test_tool_calls_work_in_passthrough_mode_with_options(self): assert len(result.tool_calls) == 2 assert isinstance(result.response, list) assert result.response[0]["role"] == "assistant" - assert "help you" in result.response[0]["content"] + assert result.response[0]["content"] == "" @pytest.mark.asyncio async def test_tool_calls_work_in_passthrough_mode_dict_response(self): @@ -89,7 +89,7 @@ async def test_tool_calls_work_in_passthrough_mode_dict_response(self): ] with patch( - "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar" + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" ) as mock_get_clear: mock_get_clear.return_value = test_tool_calls @@ -106,12 +106,12 @@ async def test_tool_calls_work_in_passthrough_mode_dict_response(self): assert "tool_calls" in result assert result["tool_calls"] == test_tool_calls assert result["role"] == "assistant" - assert "check the weather" in result["content"] + assert result["content"] == "" @pytest.mark.asyncio async def test_no_tool_calls_in_passthrough_mode(self): with patch( - "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar" + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" ) as mock_get_clear: mock_get_clear.return_value = None @@ -132,7 +132,7 @@ async def test_no_tool_calls_in_passthrough_mode(self): @pytest.mark.asyncio async def test_empty_tool_calls_in_passthrough_mode(self): with patch( - "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar" + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" ) as mock_get_clear: mock_get_clear.return_value = [] @@ -160,12 +160,14 @@ async def test_tool_calls_with_prompt_mode_passthrough(self): ] with patch( - 
"nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar" + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" ) as mock_get_clear: mock_get_clear.return_value = test_tool_calls chat = TestChat( self.passthrough_config, + # note that llm would not generate any content when tool calls are present + # this is here just to show the underlying behavior llm_completions=["I'll search for that information."], ) @@ -176,7 +178,7 @@ async def test_tool_calls_with_prompt_mode_passthrough(self): assert isinstance(result, GenerationResponse) assert result.tool_calls == test_tool_calls assert isinstance(result.response, str) - assert "search for that information" in result.response + assert result.response == "" @pytest.mark.asyncio async def test_complex_tool_calls_passthrough_integration(self): @@ -202,7 +204,7 @@ async def test_complex_tool_calls_passthrough_integration(self): ] with patch( - "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar" + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" ) as mock_get_clear: mock_get_clear.return_value = complex_tool_calls @@ -277,7 +279,7 @@ async def test_tool_calls_integration_preserves_other_response_data(self): ] with patch( - "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar" + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" ) as mock_get_clear: mock_get_clear.return_value = test_tool_calls @@ -299,7 +301,7 @@ async def test_tool_calls_integration_preserves_other_response_data(self): assert isinstance(result.response, list) assert len(result.response) == 1 assert result.response[0]["role"] == "assistant" - assert result.response[0]["content"] == "Response with preserved data." + assert result.response[0]["content"] == "" @pytest.mark.asyncio async def test_tool_calls_with_real_world_examples(self): @@ -319,7 +321,7 @@ async def test_tool_calls_with_real_world_examples(self): ] with patch( - "nemoguardrails.rails.llm.llmrails.get_and_clear_tool_calls_contextvar" + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" ) as mock_get_clear: mock_get_clear.return_value = realistic_tool_calls diff --git a/tests/test_tool_calling_passthrough_only.py b/tests/test_tool_calling_passthrough_only.py new file mode 100644 index 000000000..5791d5b0f --- /dev/null +++ b/tests/test_tool_calling_passthrough_only.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test that tool calling ONLY works in passthrough mode.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest +from langchain_core.messages import AIMessage + +from nemoguardrails import LLMRails, RailsConfig +from nemoguardrails.actions.llm.generation import LLMGenerationActions +from nemoguardrails.context import tool_calls_var + + +@pytest.fixture +def mock_llm_with_tool_calls(): + """Mock LLM that returns tool calls.""" + llm = AsyncMock() + + mock_response = AIMessage( + content="", + tool_calls=[ + { + "id": "call_123", + "type": "tool_call", + "name": "test_tool", + "args": {"param": "value"}, + } + ], + ) + llm.ainvoke.return_value = mock_response + llm.invoke.return_value = mock_response + return llm + + +@pytest.fixture +def config_passthrough(): + """Config with passthrough enabled.""" + return RailsConfig.from_content( + colang_content="", + yaml_content=""" + models: + - type: main + engine: mock + model: test-model + + rails: + input: + flows: [] + dialog: + flows: [] + output: + flows: [] + + passthrough: true + """, + ) + + +@pytest.fixture +def config_no_passthrough(): + """Config with passthrough disabled.""" + return RailsConfig.from_content( + colang_content="", + yaml_content=""" + models: + - type: main + engine: mock + model: test-model + + rails: + input: + flows: [] + dialog: + flows: [] + output: + flows: [] + + passthrough: false + """, + ) + + +class TestToolCallingPassthroughOnly: + """Test that tool calling only works in passthrough mode.""" + + def test_config_passthrough_true(self, config_passthrough): + """Test that passthrough config is correctly set.""" + assert config_passthrough.passthrough is True + + def test_config_passthrough_false(self, config_no_passthrough): + """Test that non-passthrough config is correctly set.""" + assert config_no_passthrough.passthrough is False + + @pytest.mark.asyncio + async def test_tool_calls_work_in_passthrough_mode( + self, config_passthrough, mock_llm_with_tool_calls + ): + """Test that tool calls create BotToolCall events in passthrough mode.""" + tool_calls = [ + { + "id": "call_123", + "type": "tool_call", + "name": "test_tool", + "args": {"param": "value"}, + } + ] + tool_calls_var.set(tool_calls) + + generation_actions = LLMGenerationActions( + config=config_passthrough, + llm=mock_llm_with_tool_calls, + llm_task_manager=MagicMock(), + get_embedding_search_provider_instance=MagicMock(return_value=None), + ) + + events = [{"type": "UserMessage", "text": "test"}] + context = {} + + result = await generation_actions.generate_user_intent( + events=events, context=context, config=config_passthrough + ) + + assert len(result.events) == 1 + assert result.events[0]["type"] == "BotToolCall" + assert result.events[0]["tool_calls"] == tool_calls + + @pytest.mark.asyncio + async def test_tool_calls_ignored_in_non_passthrough_mode( + self, config_no_passthrough, mock_llm_with_tool_calls + ): + """Test that tool calls are ignored when not in passthrough mode.""" + tool_calls = [ + { + "id": "call_123", + "type": "tool_call", + "name": "test_tool", + "args": {"param": "value"}, + } + ] + tool_calls_var.set(tool_calls) + + generation_actions = LLMGenerationActions( + config=config_no_passthrough, + llm=mock_llm_with_tool_calls, + llm_task_manager=MagicMock(), + get_embedding_search_provider_instance=MagicMock(return_value=None), + ) + + events = [{"type": "UserMessage", "text": "test"}] + context = {} + + result = await generation_actions.generate_user_intent( + events=events, context=context, 
config=config_no_passthrough + ) + + assert len(result.events) == 1 + assert result.events[0]["type"] == "BotMessage" + assert "tool_calls" not in result.events[0] + + @pytest.mark.asyncio + async def test_no_tool_calls_creates_bot_message_in_passthrough( + self, config_passthrough, mock_llm_with_tool_calls + ): + """Test that no tool calls creates BotMessage event even in passthrough mode.""" + tool_calls_var.set(None) + + mock_response_no_tools = AIMessage(content="Regular text response") + mock_llm_with_tool_calls.ainvoke.return_value = mock_response_no_tools + mock_llm_with_tool_calls.invoke.return_value = mock_response_no_tools + + generation_actions = LLMGenerationActions( + config=config_passthrough, + llm=mock_llm_with_tool_calls, + llm_task_manager=MagicMock(), + get_embedding_search_provider_instance=MagicMock(return_value=None), + ) + + events = [{"type": "UserMessage", "text": "test"}] + context = {} + + result = await generation_actions.generate_user_intent( + events=events, context=context, config=config_passthrough + ) + + assert len(result.events) == 1 + assert result.events[0]["type"] == "BotMessage" + + def test_llm_rails_integration_passthrough_mode( + self, config_passthrough, mock_llm_with_tool_calls + ): + """Test LLMRails with passthrough mode allows tool calls.""" + rails = LLMRails(config=config_passthrough, llm=mock_llm_with_tool_calls) + + assert rails.config.passthrough is True + + def test_llm_rails_integration_non_passthrough_mode( + self, config_no_passthrough, mock_llm_with_tool_calls + ): + """Test LLMRails without passthrough mode.""" + rails = LLMRails(config=config_no_passthrough, llm=mock_llm_with_tool_calls) + + assert rails.config.passthrough is False diff --git a/tests/test_tool_calls_event_extraction.py b/tests/test_tool_calls_event_extraction.py new file mode 100644 index 000000000..4a2d0f5fd --- /dev/null +++ b/tests/test_tool_calls_event_extraction.py @@ -0,0 +1,506 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for event-based tool_calls extraction.""" + +import pytest +from langchain_core.messages import AIMessage, HumanMessage + +from nemoguardrails import RailsConfig +from nemoguardrails.actions import action +from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails +from tests.utils import TestChat + + +@action(is_system_action=True) +async def validate_tool_parameters(tool_calls, context=None, **kwargs): + """Test implementation of tool parameter validation.""" + tool_calls = tool_calls or (context.get("tool_calls", []) if context else []) + + dangerous_patterns = ["eval", "exec", "system", "../", "rm -", "DROP", "DELETE"] + + for tool_call in tool_calls: + args = tool_call.get("args", {}) + for param_value in args.values(): + if isinstance(param_value, str): + if any( + pattern.lower() in param_value.lower() + for pattern in dangerous_patterns + ): + return False + return True + + +@action(is_system_action=True) +async def self_check_tool_calls(tool_calls, context=None, **kwargs): + """Test implementation of tool call validation.""" + tool_calls = tool_calls or (context.get("tool_calls", []) if context else []) + + return all( + isinstance(call, dict) and "name" in call and "id" in call + for call in tool_calls + ) + + +@pytest.mark.asyncio +async def test_tool_calls_preserved_when_rails_block(): + test_tool_calls = [ + { + "name": "dangerous_tool", + "args": {"param": "eval('malicious code')"}, + "id": "call_dangerous", + "type": "tool_call", + } + ] + + config = RailsConfig.from_content( + """ + define subflow validate tool parameters + $valid = execute validate_tool_parameters(tool_calls=$tool_calls) + + if not $valid + bot refuse dangerous tool parameters + abort + + define bot refuse dangerous tool parameters + "I cannot execute this tool request because the parameters may be unsafe." 
+ """, + """ + models: [] + passthrough: true + rails: + tool_output: + flows: + - validate tool parameters + """, + ) + + class MockLLMWithDangerousTools: + def invoke(self, messages, **kwargs): + return AIMessage(content="", tool_calls=test_tool_calls) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMWithDangerousTools()) + + rails.rails.runtime.register_action( + validate_tool_parameters, name="validate_tool_parameters" + ) + rails.rails.runtime.register_action( + self_check_tool_calls, name="self_check_tool_calls" + ) + result = await rails.ainvoke(HumanMessage(content="Execute dangerous tool")) + + assert ( + result.tool_calls is not None + ), "tool_calls should be preserved in final response" + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "dangerous_tool" + assert "cannot execute this tool request" in result.content + + +@pytest.mark.asyncio +async def test_generation_action_pops_tool_calls_once(): + from unittest.mock import patch + + test_tool_calls = [ + { + "name": "test_tool", + "args": {"param": "value"}, + "id": "call_test", + "type": "tool_call", + } + ] + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + call_count = 0 + + def mock_get_and_clear(): + nonlocal call_count + call_count += 1 + if call_count == 1: + return test_tool_calls + return None + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar", + side_effect=mock_get_and_clear, + ): + chat = TestChat(config, llm_completions=[""]) + + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Test"}] + ) + + assert call_count >= 1, "get_and_clear_tool_calls_contextvar should be called" + assert result["tool_calls"] is not None + assert result["tool_calls"][0]["name"] == "test_tool" + + +@pytest.mark.asyncio +async def test_llmrails_extracts_tool_calls_from_events(): + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + test_tool_calls = [ + { + "name": "extract_test", + "args": {"data": "test"}, + "id": "call_extract", + "type": "tool_call", + } + ] + + mock_events = [ + {"type": "BotToolCall", "tool_calls": test_tool_calls, "uid": "test_uid"} + ] + + from nemoguardrails.actions.llm.utils import extract_tool_calls_from_events + + extracted_tool_calls = extract_tool_calls_from_events(mock_events) + + assert extracted_tool_calls is not None + assert len(extracted_tool_calls) == 1 + assert extracted_tool_calls[0]["name"] == "extract_test" + + +@pytest.mark.asyncio +async def test_tool_rails_cannot_clear_context_variable(): + from nemoguardrails.context import tool_calls_var + + test_tool_calls = [ + { + "name": "blocked_tool", + "args": {"param": "rm -rf /"}, + "id": "call_blocked", + "type": "tool_call", + } + ] + + tool_calls_var.set(test_tool_calls) + + context = {"tool_calls": test_tool_calls} + result = await validate_tool_parameters(test_tool_calls, context=context) + + assert result is False + assert ( + tool_calls_var.get() is not None + ), "Context variable should not be cleared by tool rails" + assert tool_calls_var.get()[0]["name"] == "blocked_tool" + + +@pytest.mark.asyncio +async def test_complete_fix_integration(): + """Integration test demonstrating the complete fix for tool_calls preservation.""" + + dangerous_tool_calls = [ + { + "name": "dangerous_function", + "args": {"code": "eval('malicious')"}, + "id": "call_dangerous_123", + "type": "tool_call", + } + ] + + config = 
RailsConfig.from_content( + """ + define subflow validate tool parameters + $valid = execute validate_tool_parameters(tool_calls=$tool_calls) + + if not $valid + bot refuse dangerous tool parameters + abort + + define bot refuse dangerous tool parameters + "I cannot execute this request due to security concerns." + """, + """ + models: [] + passthrough: true + rails: + tool_output: + flows: + - validate tool parameters + """, + ) + + class MockLLMReturningDangerousTools: + def invoke(self, messages, **kwargs): + return AIMessage(content="", tool_calls=dangerous_tool_calls) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMReturningDangerousTools()) + + rails.rails.runtime.register_action( + validate_tool_parameters, name="validate_tool_parameters" + ) + rails.rails.runtime.register_action( + self_check_tool_calls, name="self_check_tool_calls" + ) + result = await rails.ainvoke(HumanMessage(content="Run dangerous code")) + + assert "security concerns" in result.content + + assert result.tool_calls is not None, "tool_calls preserved despite being blocked" + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "dangerous_function" + + +@pytest.mark.asyncio +async def test_passthrough_mode_with_multiple_tool_calls(): + test_tool_calls = [ + { + "name": "get_weather", + "args": {"location": "NYC"}, + "id": "call_123", + "type": "tool_call", + }, + { + "name": "calculate", + "args": {"a": 2, "b": 2}, + "id": "call_456", + "type": "tool_call", + }, + ] + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + class MockLLMWithMultipleTools: + def invoke(self, messages, **kwargs): + return AIMessage( + content="I'll help you with the weather and calculation.", + tool_calls=test_tool_calls, + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMWithMultipleTools()) + result = await rails.ainvoke( + HumanMessage(content="What's the weather in NYC and what's 2+2?") + ) + + assert result.tool_calls is not None + assert len(result.tool_calls) == 2 + assert result.tool_calls[0]["name"] == "get_weather" + assert result.tool_calls[1]["name"] == "calculate" + assert result.content == "" + + +@pytest.mark.asyncio +async def test_passthrough_mode_no_tool_calls(): + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + class MockLLMNoTools: + def invoke(self, messages, **kwargs): + return AIMessage(content="I can help with general questions.") + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMNoTools()) + result = await rails.ainvoke(HumanMessage(content="Hello")) + + assert result.tool_calls is None or result.tool_calls == [] + assert result.content == "I can help with general questions." + + +@pytest.mark.asyncio +async def test_passthrough_mode_empty_tool_calls(): + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + class MockLLMEmptyTools: + def invoke(self, messages, **kwargs): + return AIMessage(content="No tools needed.", tool_calls=[]) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMEmptyTools()) + result = await rails.ainvoke(HumanMessage(content="Simple question")) + + assert result.tool_calls == [] + assert result.content == "No tools needed." 
+ + +@pytest.mark.asyncio +async def test_tool_calls_with_prompt_response(): + test_tool_calls = [ + { + "name": "search", + "args": {"query": "latest news"}, + "id": "call_prompt", + "type": "tool_call", + } + ] + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + class MockLLMPromptMode: + def invoke(self, messages, **kwargs): + return AIMessage(content="", tool_calls=test_tool_calls) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMPromptMode()) + result = await rails.ainvoke(HumanMessage(content="Get me the latest news")) + + assert result.tool_calls is not None + assert len(result.tool_calls) == 1 + assert result.tool_calls[0]["name"] == "search" + assert result.tool_calls[0]["args"]["query"] == "latest news" + + +@pytest.mark.asyncio +async def test_tool_calls_preserve_metadata(): + test_tool_calls = [ + { + "name": "preserve_test", + "args": {"data": "preserved"}, + "id": "call_preserve", + "type": "tool_call", + } + ] + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + class MockLLMWithMetadata: + def invoke(self, messages, **kwargs): + msg = AIMessage( + content="Processing with metadata.", tool_calls=test_tool_calls + ) + msg.response_metadata = {"model": "test-model", "usage": {"tokens": 50}} + return msg + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMWithMetadata()) + result = await rails.ainvoke(HumanMessage(content="Process with metadata")) + + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "preserve_test" + assert result.content == "" + assert hasattr(result, "response_metadata") + + +@pytest.mark.asyncio +async def test_tool_output_rails_blocking_behavior(): + dangerous_tool_calls = [ + { + "name": "dangerous_exec", + "args": {"command": "rm -rf /"}, + "id": "call_dangerous_exec", + "type": "tool_call", + } + ] + + config = RailsConfig.from_content( + """ + define subflow validate tool parameters + $valid = execute validate_tool_parameters(tool_calls=$tool_calls) + + if not $valid + bot refuse dangerous tool parameters + abort + + define bot refuse dangerous tool parameters + "Tool blocked for security reasons." 
+ """, + """ + models: [] + passthrough: true + rails: + tool_output: + flows: + - validate tool parameters + """, + ) + + class MockLLMDangerousExec: + def invoke(self, messages, **kwargs): + return AIMessage(content="", tool_calls=dangerous_tool_calls) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMDangerousExec()) + + rails.rails.runtime.register_action( + validate_tool_parameters, name="validate_tool_parameters" + ) + rails.rails.runtime.register_action( + self_check_tool_calls, name="self_check_tool_calls" + ) + result = await rails.ainvoke(HumanMessage(content="Execute dangerous command")) + + assert "security reasons" in result.content + assert result.tool_calls is not None + assert result.tool_calls[0]["name"] == "dangerous_exec" + assert "rm -rf" in result.tool_calls[0]["args"]["command"] + + +@pytest.mark.asyncio +async def test_complex_tool_calls_integration(): + complex_tool_calls = [ + { + "name": "search_database", + "args": {"table": "users", "query": "active=true"}, + "id": "call_db_search", + "type": "tool_call", + }, + { + "name": "format_results", + "args": {"format": "json", "limit": 10}, + "id": "call_format", + "type": "tool_call", + }, + ] + + config = RailsConfig.from_content(config={"models": [], "passthrough": True}) + + class MockLLMComplexTools: + def invoke(self, messages, **kwargs): + return AIMessage( + content="I'll search the database and format the results.", + tool_calls=complex_tool_calls, + ) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + rails = RunnableRails(config, llm=MockLLMComplexTools()) + result = await rails.ainvoke( + HumanMessage(content="Find active users and format as JSON") + ) + + assert result.tool_calls is not None + assert len(result.tool_calls) == 2 + + db_call = result.tool_calls[0] + assert db_call["name"] == "search_database" + assert db_call["args"]["table"] == "users" + assert db_call["args"]["query"] == "active=true" + + format_call = result.tool_calls[1] + assert format_call["name"] == "format_results" + assert format_call["args"]["format"] == "json" + assert format_call["args"]["limit"] == 10 + + assert result.content == "" diff --git a/tests/test_tool_output_rails.py b/tests/test_tool_output_rails.py new file mode 100644 index 000000000..7f1d963c1 --- /dev/null +++ b/tests/test_tool_output_rails.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for tool output rails (Phase 2) functionality.""" + +from unittest.mock import patch + +import pytest + +from nemoguardrails import RailsConfig +from nemoguardrails.actions import action +from tests.utils import TestChat + + +@action(is_system_action=True) +async def validate_tool_parameters(tool_calls, context=None, **kwargs): + """Test implementation of tool parameter validation.""" + tool_calls = tool_calls or (context.get("tool_calls", []) if context else []) + + dangerous_patterns = ["eval", "exec", "system", "../", "rm -", "DROP", "DELETE"] + + for tool_call in tool_calls: + args = tool_call.get("args", {}) + for param_value in args.values(): + if isinstance(param_value, str): + if any( + pattern.lower() in param_value.lower() + for pattern in dangerous_patterns + ): + return False + return True + + +@action(is_system_action=True) +async def self_check_tool_calls(tool_calls, context=None, **kwargs): + """Test implementation of tool call validation.""" + tool_calls = tool_calls or (context.get("tool_calls", []) if context else []) + + return all( + isinstance(call, dict) and "name" in call and "id" in call + for call in tool_calls + ) + + +@pytest.mark.asyncio +async def test_tool_output_rails_basic(): + """Test basic tool output rails functionality.""" + + test_tool_calls = [ + { + "name": "allowed_tool", + "args": {"param": "safe_value"}, + "id": "call_safe", + "type": "tool_call", + } + ] + + # Config with tool output rails + config = RailsConfig.from_content( + """ + define subflow self check tool calls + $allowed = execute self_check_tool_calls(tool_calls=$tool_calls) + + if not $allowed + bot refuse tool execution + abort + + define bot refuse tool execution + "I cannot execute this tool request due to policy restrictions." + """, + """ + models: [] + passthrough: true + rails: + tool_output: + flows: + - self check tool calls + """, + ) + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + chat = TestChat(config, llm_completions=[""]) + + chat.app.runtime.register_action( + validate_tool_parameters, name="validate_tool_parameters" + ) + chat.app.runtime.register_action( + self_check_tool_calls, name="self_check_tool_calls" + ) + + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Use allowed tool"}] + ) + + # Tool should be allowed through + assert result["tool_calls"] is not None + assert result["tool_calls"][0]["name"] == "allowed_tool" + + +@pytest.mark.asyncio +async def test_tool_output_rails_blocking(): + """Test that tool output rails can block dangerous tools.""" + + test_tool_calls = [ + { + "name": "dangerous_tool", + "args": {"param": "eval('malicious code')"}, + "id": "call_bad", + "type": "tool_call", + } + ] + + # Config with tool parameter validation + config = RailsConfig.from_content( + """ + define subflow validate tool parameters + $valid = execute validate_tool_parameters(tool_calls=$tool_calls) + + if not $valid + bot refuse dangerous tool parameters + abort + + define bot refuse dangerous tool parameters + "I cannot execute this tool request because the parameters may be unsafe." 
+ """, + """ + models: [] + passthrough: true + rails: + tool_output: + flows: + - validate tool parameters + """, + ) + + # Create a mock LLM that returns tool calls + class MockLLMWithDangerousTool: + def invoke(self, messages, **kwargs): + from langchain_core.messages import AIMessage + + return AIMessage(content="", tool_calls=test_tool_calls) + + async def ainvoke(self, messages, **kwargs): + return self.invoke(messages, **kwargs) + + from langchain_core.messages import HumanMessage + + from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails + + rails = RunnableRails(config, llm=MockLLMWithDangerousTool()) + + rails.rails.runtime.register_action( + validate_tool_parameters, name="validate_tool_parameters" + ) + rails.rails.runtime.register_action( + self_check_tool_calls, name="self_check_tool_calls" + ) + + result = await rails.ainvoke(HumanMessage(content="Use dangerous tool")) + + assert "parameters may be unsafe" in result.content + + +@pytest.mark.asyncio +async def test_multiple_tool_output_rails(): + """Test multiple tool output rails working together.""" + + test_tool_calls = [ + { + "name": "test_tool", + "args": {"param": "safe"}, + "id": "call_test", + "type": "tool_call", + } + ] + + config = RailsConfig.from_content( + """ + define subflow self check tool calls + $allowed = execute self_check_tool_calls(tool_calls=$tool_calls) + if not $allowed + bot refuse tool execution + abort + + define subflow validate tool parameters + $valid = execute validate_tool_parameters(tool_calls=$tool_calls) + if not $valid + bot refuse dangerous tool parameters + abort + + define bot refuse tool execution + "Tool not allowed." + + define bot refuse dangerous tool parameters + "Parameters unsafe." + """, + """ + models: [] + passthrough: true + rails: + tool_output: + flows: + - self check tool calls + - validate tool parameters + """, + ) + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + chat = TestChat(config, llm_completions=[""]) + + chat.app.runtime.register_action( + validate_tool_parameters, name="validate_tool_parameters" + ) + chat.app.runtime.register_action( + self_check_tool_calls, name="self_check_tool_calls" + ) + + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Use test tool"}] + ) + + assert result["tool_calls"] is not None + assert result["tool_calls"][0]["name"] == "test_tool"
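
For reference, the following is a minimal, self-contained sketch of the usage pattern the tests above exercise. The configuration shape (`passthrough: true`, `rails.tool_output.flows`) and the API calls (`RailsConfig.from_content`, `RunnableRails`, `rails.rails.runtime.register_action`, `ainvoke`) are taken from this patch and its tests; the `EchoToolCallLLM` stub, the example tool call payload, and the `main()` entry point are illustrative assumptions only, not part of the patch.

# Minimal usage sketch, assuming this patch is applied and langchain-core is available.
# EchoToolCallLLM, the tool call payload, and main() are assumptions for illustration.
import asyncio

from langchain_core.messages import AIMessage, HumanMessage

from nemoguardrails import RailsConfig
from nemoguardrails.actions import action
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails


@action(is_system_action=True)
async def validate_tool_parameters(tool_calls, context=None, **kwargs):
    """Reject tool calls whose string parameters contain dangerous patterns."""
    tool_calls = tool_calls or (context.get("tool_calls", []) if context else [])
    dangerous = ["eval", "exec", "rm -", "DROP", "DELETE"]
    for call in tool_calls:
        for value in call.get("args", {}).values():
            if isinstance(value, str) and any(
                p.lower() in value.lower() for p in dangerous
            ):
                return False
    return True


class EchoToolCallLLM:
    """Stub LLM that always proposes a single 'search' tool call (assumption)."""

    def invoke(self, messages, **kwargs):
        return AIMessage(
            content="",
            tool_calls=[
                {
                    "name": "search",
                    "args": {"query": "latest news"},
                    "id": "call_1",
                    "type": "tool_call",
                }
            ],
        )

    async def ainvoke(self, messages, **kwargs):
        return self.invoke(messages, **kwargs)


async def main():
    # Same Colang + YAML shape used by the tests above.
    config = RailsConfig.from_content(
        """
        define subflow validate tool parameters
            $valid = execute validate_tool_parameters(tool_calls=$tool_calls)

            if not $valid
                bot refuse dangerous tool parameters
                abort

        define bot refuse dangerous tool parameters
            "I cannot execute this tool request because the parameters may be unsafe."
        """,
        """
        models: []
        passthrough: true
        rails:
          tool_output:
            flows:
              - validate tool parameters
        """,
    )

    rails = RunnableRails(config, llm=EchoToolCallLLM())
    rails.rails.runtime.register_action(
        validate_tool_parameters, name="validate_tool_parameters"
    )

    result = await rails.ainvoke(HumanMessage(content="Get me the latest news"))
    # In passthrough mode the proposed tool calls are surfaced on the result message.
    print(result.content)
    print(result.tool_calls)


if __name__ == "__main__":
    asyncio.run(main())

With a benign payload like the one above, the tool call is expected to pass through unchanged; if the arguments matched one of the dangerous patterns, the refusal message would be returned instead, mirroring test_tool_output_rails_blocking.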