875 changes: 764 additions & 111 deletions nemoguardrails/integrations/langchain/runnable_rails.py

Large diffs are not rendered by default.
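
Since this diff is not rendered, the sketch below illustrates the interface the new tests exercise; it is inferred from the tests that follow, not taken from the implementation itself. The config content and prompt text are illustrative.

# A minimal usage sketch of the updated RunnableRails, assuming only the
# behavior asserted by the tests below (plain string in -> plain string out;
# it also composes as a Runnable in LCEL chains).
from langchain_core.prompts import ChatPromptTemplate

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM

llm = FakeLLM(responses=["Paris.", "Paris is the capital of France."])
config = RailsConfig.from_content(config={"models": []})
rails = RunnableRails(config, llm=llm)

# A plain string input yields a plain string output.
print(rails.invoke("What's the capital of France?"))  # "Paris."

# RunnableRails composes with prompts like any other Runnable; with a chat
# prompt, the result is an AIMessage (see test_runnable_rails_with_chat_template).
prompt = ChatPromptTemplate.from_template("{question}")
chain = prompt | rails
print(chain.invoke({"question": "What's the capital of France?"}).content)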

152 changes: 152 additions & 0 deletions tests/runnable_rails/test_basic_operations.py
@@ -0,0 +1,152 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for basic RunnableRails operations (invoke, async, batch, stream).
"""

import pytest
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM


def test_runnable_rails_basic():
"""Test basic functionality of updated RunnableRails."""
llm = FakeLLM(
responses=[
"Hello there! How can I help you today?",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

result = model_with_rails.invoke("Hi there")

assert isinstance(result, str)
assert "Hello there" in result


@pytest.mark.asyncio
async def test_runnable_rails_async():
"""Test async functionality of updated RunnableRails."""
llm = FakeLLM(
responses=[
"Hello there! How can I help you today?",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

result = await model_with_rails.ainvoke("Hi there")

assert isinstance(result, str)
assert "Hello there" in result


def test_runnable_rails_batch():
"""Test batch functionality of updated RunnableRails."""
llm = FakeLLM(
responses=[
"Response 1",
"Response 2",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

results = model_with_rails.batch(["Question 1", "Question 2"])

assert len(results) == 2
assert results[0] == "Response 1"
assert results[1] == "Response 2"


def test_runnable_rails_stream():
"""Test streaming functionality of updated RunnableRails."""
llm = FakeLLM(
responses=[
"Hello there!",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

chunks = []
for chunk in model_with_rails.stream("Hi there"):
chunks.append(chunk)

assert len(chunks) == 2
assert chunks[0].content == "Hello "
assert chunks[1].content == "there!"


def test_runnable_rails_with_message_history():
"""Test handling of message history with updated RunnableRails."""
llm = FakeLLM(
responses=[
"Yes, Paris is the capital of France.",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

history = [
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
HumanMessage(content="What's the capital of France?"),
]

result = model_with_rails.invoke(history)

assert isinstance(result, AIMessage)
assert "Paris" in result.content


def test_runnable_rails_with_chat_template():
"""Test updated RunnableRails with chat templates."""
llm = FakeLLM(
responses=[
"Yes, Paris is the capital of France.",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

prompt = ChatPromptTemplate.from_messages(
[
MessagesPlaceholder(variable_name="history"),
("human", "{question}"),
]
)

chain = prompt | model_with_rails

result = chain.invoke(
{
"history": [
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
],
"question": "What's the capital of France?",
}
)

assert isinstance(result, AIMessage)
assert "Paris" in result.content
41 changes: 41 additions & 0 deletions tests/runnable_rails/test_batch_as_completed.py
@@ -0,0 +1,41 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for batch_as_completed methods."""

import pytest

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM


@pytest.fixture
def rails():
"""Create a RunnableRails instance for testing."""
config = RailsConfig.from_content(config={"models": []})
llm = FakeLLM(responses=["response 1", "response 2", "response 3"])
return RunnableRails(config, llm=llm)


def test_batch_as_completed_exists(rails):
"""Test that batch_as_completed method exists."""
assert hasattr(rails, "batch_as_completed")


@pytest.mark.asyncio
async def test_abatch_as_completed_exists(rails):
"""Test that abatch_as_completed method exists."""
assert hasattr(rails, "abatch_as_completed")
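
These two tests only assert that the methods are present; the sketch below shows the behavior RunnableRails would inherit from langchain_core's Runnable, assuming it is not overridden: unlike batch(), which returns results in input order, batch_as_completed() yields (index, output) pairs as each input finishes.

# A sketch of the inherited batch_as_completed behavior (assumption: the
# default Runnable implementation from langchain_core is used).
from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM

llm = FakeLLM(responses=["response 1", "response 2"])
config = RailsConfig.from_content(config={"models": []})
rails = RunnableRails(config, llm=llm)

for index, output in rails.batch_as_completed(["Question 1", "Question 2"]):
    # Pairs arrive in completion order; `index` maps each output back to its input.
    print(index, output)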
147 changes: 147 additions & 0 deletions tests/runnable_rails/test_batching.py
@@ -0,0 +1,147 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM


def test_batch_processing():
"""Test batch processing of multiple inputs."""
llm = FakeLLM(
responses=[
"Paris.",
"Rome.",
"Berlin.",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

inputs = [
"What's the capital of France?",
"What's the capital of Italy?",
"What's the capital of Germany?",
]

results = model_with_rails.batch(inputs)

assert len(results) == 3
assert results[0] == "Paris."
assert results[1] == "Rome."
assert results[2] == "Berlin."


@pytest.mark.asyncio
async def test_abatch_processing():
"""Test async batch processing of multiple inputs."""
llm = FakeLLM(
responses=[
"Paris.",
"Rome.",
"Berlin.",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

inputs = [
"What's the capital of France?",
"What's the capital of Italy?",
"What's the capital of Germany?",
]

results = await model_with_rails.abatch(inputs)

assert len(results) == 3
assert results[0] == "Paris."
assert results[1] == "Rome."
assert results[2] == "Berlin."


def test_batch_with_different_input_types():
"""Test batch processing with different input types."""
llm = FakeLLM(
responses=[
"Paris.",
"Rome.",
"Berlin.",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

inputs = [
"What's the capital of France?",
HumanMessage(content="What's the capital of Italy?"),
{"input": "What's the capital of Germany?"},
]

results = model_with_rails.batch(inputs)

assert len(results) == 3
assert results[0] == "Paris."
assert isinstance(results[1], AIMessage)
assert results[1].content == "Rome."
assert isinstance(results[2], dict)
assert results[2]["output"] == "Berlin."


def test_stream_output():
"""Test streaming output (simplified for now)."""
llm = FakeLLM(
responses=[
"Paris.",
]
)
config = RailsConfig.from_content(config={"models": []})
model_with_rails = RunnableRails(config, llm=llm)

chunks = []
for chunk in model_with_rails.stream("What's the capital of France?"):
chunks.append(chunk)

# "Paris." is a single word, so the stream yields it as a single chunk here
assert len(chunks) == 1
assert chunks[0].content == "Paris."


@pytest.mark.asyncio
async def test_astream_output():
"""Test async streaming output (simplified for now)."""
llm = FakeLLM(
responses=[
"hello what can you do?",
],
streaming=True,
)
config = RailsConfig.from_content(config={"models": [], "streaming": True})
model_with_rails = RunnableRails(config, llm=llm)

# Collect all chunks from the stream
chunks = []
async for chunk in model_with_rails.astream("What's the capital of France?"):
chunks.append(chunk)

# Stream should yield individual word chunks
assert len(chunks) == 5
assert chunks[0].content == "hello "
assert chunks[1].content == "what "
assert chunks[2].content == "can "
assert chunks[3].content == "you "
assert chunks[4].content == "do?"
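
For reference, a sketch of consuming the async stream end-to-end under the same assumptions as the test above (word-level chunks whose content fields concatenate to the full reply):

import asyncio

from nemoguardrails import RailsConfig
from nemoguardrails.integrations.langchain.runnable_rails import RunnableRails
from tests.utils import FakeLLM


async def main():
    llm = FakeLLM(responses=["hello what can you do?"], streaming=True)
    config = RailsConfig.from_content(config={"models": [], "streaming": True})
    rails = RunnableRails(config, llm=llm)

    # Reassemble the streamed chunks into the full response text.
    text = ""
    async for chunk in rails.astream("Hi there"):
        text += chunk.content
    assert text == "hello what can you do?"


asyncio.run(main())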