diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index f54761aea..5fca2b6bf 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -591,7 +591,7 @@ async def generate_user_intent( if tool_calls: output_events.append( - new_event_dict("BotToolCall", tool_calls=tool_calls) + new_event_dict("BotToolCalls", tool_calls=tool_calls) ) else: output_events.append(new_event_dict("BotMessage", text=text)) @@ -905,9 +905,23 @@ async def generate_bot_message( LLMCallInfo(task=Task.GENERATE_BOT_MESSAGE.value) ) - # We use the potentially updated $user_message. This means that even - # in passthrough mode, input rails can still alter the input. - prompt = context.get("user_message") + # In passthrough mode, we should use the full conversation history + # instead of just the last user message to preserve tool message context + raw_prompt = raw_llm_request.get() + + if raw_prompt is not None and isinstance(raw_prompt, list): + # Use the full conversation including tool messages + prompt = raw_prompt.copy() + + # Update the last user message if it was altered by input rails + user_message = context.get("user_message") + if user_message and prompt: + for i in reversed(range(len(prompt))): + if prompt[i]["role"] == "user": + prompt[i]["content"] = user_message + break + else: + prompt = context.get("user_message") generation_options: GenerationOptions = generation_options_var.get() with llm_params( diff --git a/nemoguardrails/actions/llm/utils.py b/nemoguardrails/actions/llm/utils.py index d357037da..9de7ef439 100644 --- a/nemoguardrails/actions/llm/utils.py +++ b/nemoguardrails/actions/llm/utils.py @@ -153,16 +153,23 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List: if msg_type == "user": messages.append(HumanMessage(content=msg["content"])) elif msg_type in ["bot", "assistant"]: - messages.append(AIMessage(content=msg["content"])) + tool_calls = msg.get("tool_calls") + if tool_calls: + messages.append( + AIMessage(content=msg["content"], tool_calls=tool_calls) + ) + else: + messages.append(AIMessage(content=msg["content"])) elif msg_type == "system": messages.append(SystemMessage(content=msg["content"])) elif msg_type == "tool": - messages.append( - ToolMessage( - content=msg["content"], - tool_call_id=msg.get("tool_call_id", ""), - ) + tool_message = ToolMessage( + content=msg["content"], + tool_call_id=msg.get("tool_call_id", ""), ) + if msg.get("name"): + tool_message.name = msg["name"] + messages.append(tool_message) else: raise ValueError(f"Unknown message type {msg_type}") @@ -674,16 +681,16 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]: def extract_tool_calls_from_events(events: list) -> Optional[list]: - """Extract tool_calls from BotToolCall events. + """Extract tool_calls from BotToolCalls events. Args: events: List of events to search through Returns: - tool_calls if found in BotToolCall event, None otherwise + tool_calls if found in BotToolCalls event, None otherwise """ for event in events: - if event.get("type") == "BotToolCall": + if event.get("type") == "BotToolCalls": return event.get("tool_calls") return None diff --git a/nemoguardrails/integrations/langchain/runnable_rails.py b/nemoguardrails/integrations/langchain/runnable_rails.py index 764930e2a..836f259c9 100644 --- a/nemoguardrails/integrations/langchain/runnable_rails.py +++ b/nemoguardrails/integrations/langchain/runnable_rails.py @@ -26,6 +26,7 @@ BaseMessage, HumanMessage, SystemMessage, + ToolMessage, ) from langchain_core.prompt_values import ChatPromptValue, StringPromptValue from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable @@ -231,11 +232,23 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]: def _message_to_dict(self, msg: BaseMessage) -> Dict[str, Any]: """Convert a BaseMessage to dictionary format.""" if isinstance(msg, AIMessage): - return {"role": "assistant", "content": msg.content} + result = {"role": "assistant", "content": msg.content} + if hasattr(msg, "tool_calls") and msg.tool_calls: + result["tool_calls"] = msg.tool_calls + return result elif isinstance(msg, HumanMessage): return {"role": "user", "content": msg.content} elif isinstance(msg, SystemMessage): return {"role": "system", "content": msg.content} + elif isinstance(msg, ToolMessage): + result = { + "role": "tool", + "content": msg.content, + "tool_call_id": msg.tool_call_id, + } + if hasattr(msg, "name") and msg.name: + result["name"] = msg.name + return result else: # Handle other message types role = getattr(msg, "type", "user") return {"role": role, "content": msg.content} diff --git a/nemoguardrails/rails/llm/llm_flows.co b/nemoguardrails/rails/llm/llm_flows.co index 9c4d87372..4cbedfe57 100644 --- a/nemoguardrails/rails/llm/llm_flows.co +++ b/nemoguardrails/rails/llm/llm_flows.co @@ -106,7 +106,7 @@ define parallel extension flow process bot tool call """Processes tool calls from the bot.""" priority 100 - event BotToolCall + event BotToolCalls $tool_calls = $event.tool_calls @@ -130,6 +130,40 @@ define parallel extension flow process bot tool call create event StartToolCallBotAction(tool_calls=$tool_calls) +define parallel flow process user tool messages + """Run all the tool input rails on the tool messages.""" + priority 200 + event UserToolMessages + + $tool_messages = $event["tool_messages"] + + # If we have tool input rails, we run them, otherwise we just create the user message event + if $config.rails.tool_input.flows + # If we have generation options, we make sure the tool input rails are enabled. + $tool_input_enabled = True + if $generation_options is not None + if $generation_options.rails.tool_input == False + $tool_input_enabled = False + + if $tool_input_enabled: + create event StartToolInputRails + event StartToolInputRails + + $i = 0 + while $i < len($tool_messages) + $tool_message = $tool_messages[$i].content + $tool_name = $tool_messages[$i].name + if "tool_call_id" in $tool_messages[$i] + $tool_call_id = $tool_messages[$i].tool_call_id + else + $tool_call_id = "" + + do run tool input rails + $i = $i + 1 + + create event ToolInputRailsFinished + event ToolInputRailsFinished + define parallel extension flow process bot message """Runs the output rails on a bot message.""" priority 100 @@ -214,3 +248,24 @@ define subflow run tool output rails # If all went smooth, we remove it. $triggered_tool_output_rail = None + +define subflow run tool input rails + """Runs all the tool input rails in a sequential order.""" + $tool_input_flows = $config.rails.tool_input.flows + + $i = 0 + while $i < len($tool_input_flows) + # We set the current rail as being triggered. + $triggered_tool_input_rail = $tool_input_flows[$i] + + create event StartToolInputRail(flow_id=$triggered_tool_input_rail) + event StartToolInputRail + + do $tool_input_flows[$i] + $i = $i + 1 + + create event ToolInputRailFinished(flow_id=$triggered_tool_input_rail) + event ToolInputRailFinished + + # If all went smooth, we remove it. + $triggered_tool_input_rail = None diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 4d205ef9b..0b67802bb 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -747,19 +747,24 @@ def _get_events_for_messages(self, messages: List[dict], state: Any): ) elif msg["role"] == "assistant": - action_uid = new_uuid() - start_event = new_event_dict( - "StartUtteranceBotAction", - script=msg["content"], - action_uid=action_uid, - ) - finished_event = new_event_dict( - "UtteranceBotActionFinished", - final_script=msg["content"], - is_success=True, - action_uid=action_uid, - ) - events.extend([start_event, finished_event]) + if msg.get("tool_calls"): + events.append( + {"type": "BotToolCalls", "tool_calls": msg["tool_calls"]} + ) + else: + action_uid = new_uuid() + start_event = new_event_dict( + "StartUtteranceBotAction", + script=msg["content"], + action_uid=action_uid, + ) + finished_event = new_event_dict( + "UtteranceBotActionFinished", + final_script=msg["content"], + is_success=True, + action_uid=action_uid, + ) + events.extend([start_event, finished_event]) elif msg["role"] == "context": events.append({"type": "ContextUpdate", "data": msg["content"]}) elif msg["role"] == "event": @@ -767,6 +772,49 @@ def _get_events_for_messages(self, messages: List[dict], state: Any): elif msg["role"] == "system": # Handle system messages - convert them to SystemMessage events events.append({"type": "SystemMessage", "content": msg["content"]}) + elif msg["role"] == "tool": + # For the last tool message, create grouped tool event and synthetic UserMessage + if idx == len(messages) - 1: + # Find the original user message for response generation + user_message = None + for prev_msg in reversed(messages[:idx]): + if prev_msg["role"] == "user": + user_message = prev_msg["content"] + break + + if user_message: + # If tool input rails are configured, group all tool messages + if self.config.rails.tool_input.flows: + # Collect all tool messages for grouped processing + tool_messages = [] + for tool_idx in range(len(messages)): + if messages[tool_idx]["role"] == "tool": + tool_messages.append( + { + "content": messages[tool_idx][ + "content" + ], + "name": messages[tool_idx].get( + "name", "unknown" + ), + "tool_call_id": messages[tool_idx].get( + "tool_call_id", "" + ), + } + ) + + events.append( + { + "type": "UserToolMessages", + "tool_messages": tool_messages, + } + ) + + else: + events.append( + {"type": "UserMessage", "text": user_message} + ) + else: for idx in range(len(messages)): msg = messages[idx] diff --git a/tests/input_tool_rails_actions.py b/tests/input_tool_rails_actions.py new file mode 100644 index 000000000..78fb90074 --- /dev/null +++ b/tests/input_tool_rails_actions.py @@ -0,0 +1,204 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utility functions for input tool rails tests. + +This module contains the utility functions that were previously implemented as +fake test functions in the test file. They provide the actual implementation +for tool input validation, safety checking, and sanitization. +""" + +import logging +import re +from typing import Optional + +from nemoguardrails.actions import action + +log = logging.getLogger(__name__) + + +@action(is_system_action=True) +async def self_check_tool_input( + tool_message: str = None, + tool_name: str = None, + tool_call_id: str = None, + context: Optional[dict] = None, + **kwargs, +) -> bool: + """Test implementation of basic tool input validation. + + This action performs basic validation of tool results coming from tools: + - Checks if tool results are valid and safe + - Validates the structure and content + - Performs basic security checks on tool responses + + Args: + tool_message: The content returned by the tool + tool_name: Name of the tool that returned this result + tool_call_id: ID linking to the original tool call + context: Optional context information + + Returns: + bool: True if tool input is valid, False to block + """ + tool_message = tool_message or (context.get("tool_message") if context else "") + tool_name = tool_name or (context.get("tool_name") if context else "") + tool_call_id = tool_call_id or (context.get("tool_call_id") if context else "") + + config = context.get("config") if context else None + allowed_tools = getattr(config, "allowed_tools", None) if config else None + + log.debug(f"Validating tool input from {tool_name}: {tool_message[:100]}...") + + if allowed_tools and tool_name not in allowed_tools: + log.warning(f"Tool '{tool_name}' not in allowed tools list: {allowed_tools}") + return False + + if not tool_message: + log.warning(f"Empty tool message from {tool_name}") + return False + + if not tool_name: + log.warning("Tool message received without tool name") + return False + + if not tool_call_id: + log.warning(f"Tool message from {tool_name} missing tool_call_id") + return False + + max_length = getattr(config, "max_tool_message_length", 10000) if config else 10000 + if len(tool_message) > max_length: + log.warning( + f"Tool message from {tool_name} exceeds max length: {len(tool_message)} > {max_length}" + ) + return False + + return True + + +@action(is_system_action=True) +async def validate_tool_input_safety( + tool_message: str = None, + tool_name: str = None, + context: Optional[dict] = None, + **kwargs, +) -> bool: + """Test implementation of tool input safety validation. + + This action checks tool results for potentially dangerous content: + - Detects sensitive information patterns + - Flags potentially harmful content + - Prevents dangerous data from being processed + + Args: + tool_message: The content returned by the tool + tool_name: Name of the tool that returned this result + context: Optional context information + + Returns: + bool: True if tool input is safe, False to block + """ + tool_message = tool_message or (context.get("tool_message") if context else "") + tool_name = tool_name or (context.get("tool_name") if context else "") + + if not tool_message: + return True + + log.debug(f"Validating safety of tool input from {tool_name}") + + dangerous_patterns = [ + "password", + "secret", + "api_key", + "private_key", + "token", + "credential", + " str: + """Test implementation of tool input sanitization. + + This action cleans and sanitizes tool results: + - Removes or masks sensitive information + - Truncates overly long responses + - Escapes potentially dangerous content + + Args: + tool_message: The content returned by the tool + tool_name: Name of the tool that returned this result + context: Optional context information + + Returns: + str: Sanitized tool message content + """ + tool_message = tool_message or (context.get("tool_message") if context else "") + tool_name = tool_name or (context.get("tool_name") if context else "") + + if not tool_message: + return tool_message + + log.debug(f"Sanitizing tool input from {tool_name}") + + sanitized = tool_message + + sanitized = re.sub( + r'(api[_-]?key|token|secret)["\']?\s*[:=]\s*["\']?([a-zA-Z0-9]{16,})["\']?', + r"\1: [REDACTED]", + sanitized, + flags=re.IGNORECASE, + ) + + sanitized = re.sub( + r"([a-zA-Z0-9._%+-]+)@([a-zA-Z0-9.-]+\.[a-zA-Z]{2,})", r"[USER]@\2", sanitized + ) + + config = context.get("config") if context else None + max_length = getattr(config, "max_tool_message_length", 10000) if config else 10000 + + if len(sanitized) > max_length: + log.info( + f"Truncating tool response from {tool_name}: {len(sanitized)} -> {max_length}" + ) + sanitized = sanitized[: max_length - 50] + "... [TRUNCATED]" + + return sanitized diff --git a/tests/runnable_rails/test_runnable_rails.py b/tests/runnable_rails/test_runnable_rails.py index 55ddfd101..ce4e1bd51 100644 --- a/tests/runnable_rails/test_runnable_rails.py +++ b/tests/runnable_rails/test_runnable_rails.py @@ -327,7 +327,9 @@ def test_string_passthrough_mode_on_with_dialog_rails(): info = model_with_rails.rails.explain() assert len(info.llm_calls) == 2 - assert info.llm_calls[1].prompt == "The capital of France is " + # In passthrough mode with dialog rails, the second call should use the message format + # since RunnableRails converts StringPromptValue to message list, which gets formatted as "Human: ..." + assert info.llm_calls[1].prompt == "Human: The capital of France is " assert result == "Paris." diff --git a/tests/runnable_rails/test_tool_calling.py b/tests/runnable_rails/test_tool_calling.py index fb42f357c..04a7df391 100644 --- a/tests/runnable_rails/test_tool_calling.py +++ b/tests/runnable_rails/test_tool_calling.py @@ -264,7 +264,7 @@ async def ainvoke(self, messages, **kwargs): def test_bot_tool_call_event_creation(): - """Test that BotToolCall events are created instead of BotMessage when tool_calls exist.""" + """Test that BotToolCalls events are created instead of BotMessage when tool_calls exist.""" class MockLLMReturningToolCall: def invoke(self, messages, **kwargs): diff --git a/tests/test_bot_tool_call_events.py b/tests/test_bot_tool_call_events.py index 400432e55..36035292f 100644 --- a/tests/test_bot_tool_call_events.py +++ b/tests/test_bot_tool_call_events.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for BotToolCall event handling in NeMo Guardrails.""" +"""Tests for BotToolCalls event handling in NeMo Guardrails.""" from unittest.mock import patch @@ -25,7 +25,7 @@ @pytest.mark.asyncio async def test_bot_tool_call_event_creation(): - """Test that BotToolCall events are created when tool_calls are present.""" + """Test that BotToolCalls events are created when tool_calls are present.""" test_tool_calls = [ { @@ -55,7 +55,7 @@ async def test_bot_tool_call_event_creation(): @pytest.mark.asyncio async def test_bot_message_vs_bot_tool_call_event(): - """Test that regular text creates BotMessage, tool calls create BotToolCall.""" + """Test that regular text creates BotMessage, tool calls create BotToolCalls.""" config = RailsConfig.from_content(config={"models": [], "passthrough": True}) diff --git a/tests/test_input_tool_rails.py b/tests/test_input_tool_rails.py new file mode 100644 index 000000000..bd2ddb03c --- /dev/null +++ b/tests/test_input_tool_rails.py @@ -0,0 +1,953 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for input tool rails functionality. + +This module tests the input tool rails functionality introduced to validate and secure +tool inputs/results before they are processed by the LLM. Since the @nemoguardrails/library/tool_check/ +actions and flows will be removed, this test file implements similar test actions and flows +to ensure input tool rails work as expected. +""" + +import logging +from unittest.mock import patch + +import pytest + +from nemoguardrails import RailsConfig +from tests.input_tool_rails_actions import ( + sanitize_tool_input, + self_check_tool_input, + validate_tool_input_safety, +) +from tests.utils import TestChat + +log = logging.getLogger(__name__) + + +class TestInputToolRails: + """Test class for input tool rails functionality.""" + + @pytest.mark.asyncio + async def test_user_tool_messages_event_direct_processing(self): + """Test that UserToolMessages events work correctly when created directly. + + This tests the core tool input rails functionality by creating UserToolMessages + events directly, which should work according to the commit changes. + """ + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "Tool input validation failed via direct event." + """, + """ + models: [] + passthrough: true + rails: + tool_input: + flows: + - self check tool input + """, + ) + + chat = TestChat(config, llm_completions=["Should not be reached"]) + + chat.app.runtime.register_action( + self_check_tool_input, name="test_self_check_tool_input" + ) + + from nemoguardrails.utils import new_event_dict + + tool_messages = [ + { + "content": "Sunny, 22°C", + "name": "get_weather", + "tool_call_id": "call_weather_001", + } + ] + + events = [ + new_event_dict( + "UserToolMessages", + tool_messages=tool_messages, + ) + ] + result_events = await chat.app.runtime.generate_events(events) + + tool_input_rails_finished = any( + event.get("type") == "ToolInputRailsFinished" for event in result_events + ) + assert ( + tool_input_rails_finished + ), "Expected ToolInputRailsFinished event to be generated after successful tool input validation" + + invalid_tool_messages = [ + { + "content": "Sunny, 22°C", + "name": "get_weather", + } + ] + + invalid_events = [ + new_event_dict( + "UserToolMessages", + tool_messages=invalid_tool_messages, + ) + ] + invalid_result_events = await chat.app.runtime.generate_events(invalid_events) + + blocked_found = any( + event.get("type") == "BotMessage" + and "validation failed" in event.get("text", "") + for event in invalid_result_events + ) + assert ( + blocked_found + ), f"Expected tool input to be blocked, got events: {invalid_result_events}" + + @pytest.mark.asyncio + async def test_message_to_event_conversion_fixed(self): + """Test that message-to-event conversion for tool messages now works correctly. + + This test verifies that the automatic conversion from conversation messages + to UserToolMessages events is working correctly after the fix. + """ + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "Tool input blocked via message processing." + """, + """ + models: [] + passthrough: true + rails: + tool_input: + flows: + - self check tool input + """, + ) + + chat = TestChat(config, llm_completions=["Normal LLM response"]) + + chat.app.runtime.register_action( + self_check_tool_input, name="test_self_check_tool_input" + ) + + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_weather", + "args": {"city": "Paris"}, + "id": "call_weather_001", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": "", + "name": "get_weather", + "tool_call_id": "call_weather_001", + }, + ] + + result = await chat.app.generate_async(messages=messages) + + assert ( + "Tool input blocked" in result["content"] + ), f"Expected tool input to be blocked, got: {result['content']}" + + @pytest.mark.asyncio + async def test_tool_input_validation_blocking(self): + """Test that tool input validation can block invalid tool responses.""" + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "I cannot process this tool response due to validation issues." + """, + """ + models: [] + passthrough: true + rails: + tool_input: + flows: + - self check tool input + """, + ) + + chat = TestChat(config, llm_completions=[""]) + + chat.app.runtime.register_action( + self_check_tool_input, name="test_self_check_tool_input" + ) + + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_weather", + "args": {"city": "Paris"}, + "id": "call_weather_001", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": "Sunny, 22°C", + "name": "get_weather", + }, + ] + + result = await chat.app.generate_async(messages=messages) + + assert ( + "validation issues" in result["content"] + ), f"Expected validation to block missing tool_call_id, got: {result['content']}" + + @pytest.mark.asyncio + async def test_tool_input_safety_validation(self): + """Test tool input safety validation blocks dangerous content.""" + config = RailsConfig.from_content( + """ + define flow validate tool input safety + $safe = execute test_validate_tool_input_safety(tool_message=$tool_message, tool_name=$tool_name) + + if not $safe + bot refuse unsafe tool input + abort + + define bot refuse unsafe tool input + "I cannot process this tool response due to safety concerns." + """, + """ + models: [] + passthrough: true + rails: + tool_input: + flows: + - validate tool input safety + """, + ) + + chat = TestChat(config, llm_completions=[""]) + + chat.app.runtime.register_action( + validate_tool_input_safety, name="test_validate_tool_input_safety" + ) + + messages = [ + {"role": "user", "content": "Get my credentials"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_credentials", + "args": {}, + "id": "call_creds_001", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": "Your api_key is sk-1234567890abcdef and password is secret123", + "name": "get_credentials", + "tool_call_id": "call_creds_001", + }, + ] + + result = await chat.app.generate_async(messages=messages) + + assert "safety concerns" in result["content"] + + @pytest.mark.asyncio + async def test_tool_input_sanitization(self): + """Test tool input sanitization processes sensitive information without blocking. + + This test verifies that the sanitization rail runs on tool inputs containing + sensitive data and processes them appropriately without blocking the conversation. + """ + config = RailsConfig.from_content( + """ + define flow sanitize tool input + $sanitized = execute test_sanitize_tool_input(tool_message=$tool_message, tool_name=$tool_name) + $tool_message = $sanitized + + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "I cannot process this tool response." + """, + """ + models: [] + passthrough: true + rails: + tool_input: + flows: + - sanitize tool input + - self check tool input + """, + ) + + chat = TestChat( + config, + llm_completions=["I found your account information from the database."], + ) + + chat.app.runtime.register_action( + sanitize_tool_input, name="test_sanitize_tool_input" + ) + chat.app.runtime.register_action( + self_check_tool_input, name="test_self_check_tool_input" + ) + + messages = [ + {"role": "user", "content": "Look up my account"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "lookup_account", + "args": {}, + "id": "call_lookup_001", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": "User email: john.doe@example.com, API token = abcd1234567890xyzABC", + "name": "lookup_account", + "tool_call_id": "call_lookup_001", + }, + ] + + sanitized_result = await sanitize_tool_input( + tool_message="User email: john.doe@example.com, API token = abcd1234567890xyzABC", + tool_name="lookup_account", + ) + + assert ( + "[USER]@example.com" in sanitized_result + ), f"Email not sanitized: {sanitized_result}" + assert ( + "[REDACTED]" in sanitized_result + ), f"API token not sanitized: {sanitized_result}" + assert ( + "john.doe" not in sanitized_result + ), f"Username not masked: {sanitized_result}" + assert ( + "abcd1234567890xyzABC" not in sanitized_result + ), f"API token not masked: {sanitized_result}" + + result = await chat.app.generate_async(messages=messages) + + assert ( + "cannot process" not in result["content"].lower() + ), f"Unexpected blocking: {result['content']}" + + @pytest.mark.asyncio + async def test_multiple_tool_input_rails(self): + """Test multiple tool input rails working together.""" + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + if not $allowed + bot refuse tool input + abort + + define flow validate tool input safety + $safe = execute test_validate_tool_input_safety(tool_message=$tool_message, tool_name=$tool_name) + if not $safe + bot refuse unsafe tool input + abort + + define bot refuse tool input + "Tool validation failed." + + define bot refuse unsafe tool input + "Tool safety check failed." + """, + """ + models: [] + passthrough: true + rails: + tool_input: + flows: + - self check tool input + - validate tool input safety + """, + ) + + chat = TestChat( + config, + llm_completions=["The weather information shows it's sunny."], + ) + + chat.app.runtime.register_action( + self_check_tool_input, name="test_self_check_tool_input" + ) + chat.app.runtime.register_action( + validate_tool_input_safety, name="test_validate_tool_input_safety" + ) + + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_weather", + "args": {"city": "Paris"}, + "id": "call_weather_001", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": "Sunny, 22°C in Paris", + "name": "get_weather", + "tool_call_id": "call_weather_001", + }, + ] + + from nemoguardrails.utils import new_event_dict + + events = [ + new_event_dict( + "UserToolMessages", + tool_messages=[ + { + "content": "Sunny, 22°C", + "name": "get_weather", + "tool_call_id": "call_weather_001", + } + ], + ) + ] + + result_events = await chat.app.runtime.generate_events(events) + + safety_rail_finished = any( + event.get("type") == "ToolInputRailFinished" + and event.get("flow_id") == "validate tool input safety" + for event in result_events + ) + validation_rail_finished = any( + event.get("type") == "ToolInputRailFinished" + and event.get("flow_id") == "self check tool input" + for event in result_events + ) + + assert safety_rail_finished, "Safety rail should have completed" + assert validation_rail_finished, "Validation rail should have completed" + + @pytest.mark.asyncio + async def test_multiple_tool_messages_processing(self): + """Test processing multiple tool messages in UserToolMessages event.""" + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "Tool validation failed." + """, + """ + models: + - type: main + engine: mock + model: test-model + rails: + tool_input: + flows: + - self check tool input + """, + ) + + chat = TestChat( + config, + llm_completions=[ + "The weather is sunny in Paris and AAPL stock is at $150.25." + ], + ) + + chat.app.runtime.register_action( + self_check_tool_input, name="test_self_check_tool_input" + ) + + messages = [ + { + "role": "user", + "content": "Get weather for Paris and stock price for AAPL", + }, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_weather", + "args": {"city": "Paris"}, + "id": "call_weather_001", + "type": "tool_call", + }, + { + "name": "get_stock_price", + "args": {"symbol": "AAPL"}, + "id": "call_stock_001", + "type": "tool_call", + }, + ], + }, + { + "role": "tool", + "content": "Sunny, 22°C", + "name": "get_weather", + "tool_call_id": "call_weather_001", + }, + { + "role": "tool", + "content": "$150.25", + "name": "get_stock_price", + "tool_call_id": "call_stock_001", + }, + ] + + result = await chat.app.generate_async(messages=messages) + + assert ( + "validation issues" not in result["content"] + ), f"Unexpected validation block: {result['content']}" + + @pytest.mark.asyncio + async def test_tool_input_rails_with_allowed_tools_config(self): + """Test tool input rails respecting allowed tools configuration.""" + + class CustomConfig: + def __init__(self): + self.allowed_tools = ["get_weather", "get_time"] + self.max_tool_message_length = 10000 + + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "Tool not allowed." + """, + """ + models: + - type: main + engine: mock + model: test-model + rails: + tool_input: + flows: + - self check tool input + """, + ) + + chat = TestChat(config, llm_completions=[""]) + + async def patched_self_check_tool_input(*args, **kwargs): + context = kwargs.get("context", {}) + context["config"] = CustomConfig() + kwargs["context"] = context + return await self_check_tool_input(*args, **kwargs) + + chat.app.runtime.register_action( + patched_self_check_tool_input, name="test_self_check_tool_input" + ) + + messages = [ + {"role": "user", "content": "Execute dangerous operation"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "dangerous_tool", + "args": {}, + "id": "call_danger_001", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": "Operation completed", + "name": "dangerous_tool", + "tool_call_id": "call_danger_001", + }, + ] + + result = await chat.app.generate_async(messages=messages) + + assert "not allowed" in result["content"] + + @pytest.mark.asyncio + async def test_oversized_tool_message_blocked(self): + """Test that oversized tool messages are blocked by validation.""" + + class CustomConfig: + def __init__(self): + self.max_tool_message_length = 50 + + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "Tool response too long." + """, + """ + models: + - type: main + engine: mock + model: test-model + rails: + tool_input: + flows: + - self check tool input + """, + ) + + chat = TestChat(config, llm_completions=[""]) + + async def patched_self_check_tool_input(*args, **kwargs): + context = kwargs.get("context", {}) + context["config"] = CustomConfig() + kwargs["context"] = context + return await self_check_tool_input(*args, **kwargs) + + chat.app.runtime.register_action( + patched_self_check_tool_input, name="test_self_check_tool_input" + ) + + large_message = "This is a very long tool response that exceeds the maximum allowed length and should be blocked by the validation" + + messages = [ + {"role": "user", "content": "Get large data"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_large_data", + "args": {}, + "id": "call_large_001", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": large_message, + "name": "get_large_data", + "tool_call_id": "call_large_001", + }, + ] + + result = await chat.app.generate_async(messages=messages) + + assert "too long" in result["content"] + + +class TestBotToolCallsEventChanges: + """Test the changes from BotToolCall to BotToolCalls event.""" + + @pytest.mark.asyncio + async def test_bot_tool_calls_event_generated(self): + """Test that BotToolCalls events are generated (not BotToolCall).""" + test_tool_calls = [ + { + "name": "test_function", + "args": {"param1": "value1"}, + "id": "call_12345", + "type": "tool_call", + } + ] + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + config = RailsConfig.from_content( + config={"models": [], "passthrough": True} + ) + chat = TestChat(config, llm_completions=[""]) + + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Test"}] + ) + + assert result["tool_calls"] is not None + assert len(result["tool_calls"]) == 1 + assert result["tool_calls"][0]["name"] == "test_function" + + @pytest.mark.asyncio + async def test_multiple_tool_calls_in_bot_tool_calls_event(self): + """Test that multiple tool calls are handled in BotToolCalls event.""" + test_tool_calls = [ + { + "name": "tool_one", + "args": {"param": "first"}, + "id": "call_one", + "type": "tool_call", + }, + { + "name": "tool_two", + "args": {"param": "second"}, + "id": "call_two", + "type": "tool_call", + }, + ] + + with patch( + "nemoguardrails.actions.llm.utils.get_and_clear_tool_calls_contextvar" + ) as mock_get_clear: + mock_get_clear.return_value = test_tool_calls + + config = RailsConfig.from_content( + config={"models": [], "passthrough": True} + ) + chat = TestChat(config, llm_completions=[""]) + + result = await chat.app.generate_async( + messages=[{"role": "user", "content": "Execute multiple tools"}] + ) + + assert result["tool_calls"] is not None + assert len(result["tool_calls"]) == 2 + assert result["tool_calls"][0]["name"] == "tool_one" + assert result["tool_calls"][1]["name"] == "tool_two" + + +class TestUserToolMessagesEventProcessing: + """Test the new UserToolMessages event processing.""" + + @pytest.mark.asyncio + async def test_user_tool_messages_validation_failure(self): + """Test that UserToolMessages processing can fail validation.""" + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "Tool input validation failed." + """, + """ + models: + - type: main + engine: mock + model: test-model + rails: + tool_input: + flows: + - self check tool input + """, + ) + + chat = TestChat(config, llm_completions=[""]) + + chat.app.runtime.register_action( + self_check_tool_input, name="test_self_check_tool_input" + ) + + messages = [ + {"role": "user", "content": "Get weather and stock data"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_weather", + "args": {"city": "Paris"}, + "id": "call_weather_001", + "type": "tool_call", + }, + { + "name": "get_stock_price", + "args": {"symbol": "AAPL"}, + "id": "call_stock_001", + "type": "tool_call", + }, + ], + }, + { + "role": "tool", + "content": "Sunny, 22°C", + "name": "get_weather", + "tool_call_id": "call_weather_001", + }, + { + "role": "tool", + "content": "$150.25", + "name": "get_stock_price", + }, + ] + + result = await chat.app.generate_async(messages=messages) + + from nemoguardrails.utils import new_event_dict + + invalid_events = [ + new_event_dict( + "UserToolMessages", + tool_messages=[ + { + "content": "$150.25", + "name": "get_stock_price", + } + ], + ) + ] + + invalid_result_events = await chat.app.runtime.generate_events(invalid_events) + + blocked_found = any( + event.get("type") == "BotMessage" + and "validation failed" in event.get("text", "") + for event in invalid_result_events + ) + assert ( + blocked_found + ), f"Expected tool input to be blocked, got events: {invalid_result_events}" + + +class TestInputToolRailsIntegration: + """Integration tests for input tool rails with the broader system.""" + + @pytest.mark.asyncio + async def test_input_tool_rails_disabled_generation_options(self): + """Test input tool rails can be disabled via generation options.""" + config = RailsConfig.from_content( + """ + define flow self check tool input + $allowed = execute test_self_check_tool_input(tool_message=$tool_message, tool_name=$tool_name, tool_call_id=$tool_call_id) + if not $allowed + bot refuse tool input + abort + + define bot refuse tool input + "Input validation blocked this." + """, + """ + models: [] + passthrough: true + rails: + tool_input: + flows: + - self check tool input + """, + ) + + chat = TestChat( + config, + llm_completions=["Weather processed without validation."], + ) + + chat.app.runtime.register_action( + self_check_tool_input, name="test_self_check_tool_input" + ) + + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "name": "get_weather", + "args": {"city": "Paris"}, + "id": "call_weather_001", + "type": "tool_call", + } + ], + }, + { + "role": "tool", + "content": "", + "name": "get_weather", + "tool_call_id": "call_weather_001", + }, + ] + + result = await chat.app.generate_async( + messages=messages, options={"rails": {"tool_input": False}} + ) + + content = result.response[0]["content"] if result.response else "" + assert ( + "Input validation blocked" not in content + ), f"Tool input rails should be disabled but got blocking: {content}" + + assert ( + "Weather processed without validation" in content + ), f"Expected LLM completion when tool input rails disabled: {content}" diff --git a/tests/test_tool_calling_passthrough_only.py b/tests/test_tool_calling_passthrough_only.py index 5791d5b0f..1087ebf37 100644 --- a/tests/test_tool_calling_passthrough_only.py +++ b/tests/test_tool_calling_passthrough_only.py @@ -109,7 +109,8 @@ def test_config_passthrough_false(self, config_no_passthrough): async def test_tool_calls_work_in_passthrough_mode( self, config_passthrough, mock_llm_with_tool_calls ): - """Test that tool calls create BotToolCall events in passthrough mode.""" + """Test that tool calls create BotToolCalls events in passthrough mode.""" + # Set up context with tool calls tool_calls = [ { "id": "call_123", @@ -135,7 +136,7 @@ async def test_tool_calls_work_in_passthrough_mode( ) assert len(result.events) == 1 - assert result.events[0]["type"] == "BotToolCall" + assert result.events[0]["type"] == "BotToolCalls" assert result.events[0]["tool_calls"] == tool_calls @pytest.mark.asyncio diff --git a/tests/test_tool_calls_event_extraction.py b/tests/test_tool_calls_event_extraction.py index 4a2d0f5fd..a53df31c8 100644 --- a/tests/test_tool_calls_event_extraction.py +++ b/tests/test_tool_calls_event_extraction.py @@ -165,7 +165,7 @@ async def test_llmrails_extracts_tool_calls_from_events(): ] mock_events = [ - {"type": "BotToolCall", "tool_calls": test_tool_calls, "uid": "test_uid"} + {"type": "BotToolCalls", "tool_calls": test_tool_calls, "uid": "test_uid"} ] from nemoguardrails.actions.llm.utils import extract_tool_calls_from_events