test: Mock Claude models to improve LLM test reliability #1968

Closed · wants to merge 1 commit
98 changes: 52 additions & 46 deletions tests/system/small/ml/test_llm.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable
from unittest import mock

import pandas as pd
@@ -25,6 +26,21 @@
from bigframes.testing import utils


@pytest.fixture(scope="function")
def text_generator_model(request, bq_connection, session):
"""Creates a text generator model, mocking creation for Claude models."""
model_class = request.param
if model_class == llm.Claude3TextGenerator:
# For Claude, mock the BQML model creation to avoid the network call
# that fails due to the region issue.
with mock.patch.object(llm.Claude3TextGenerator, "_create_bqml_model"):
model = model_class(connection_name=bq_connection, session=session)
else:
# For other models like Gemini, create as usual.
model = model_class(connection_name=bq_connection, session=session)
yield model


@pytest.mark.parametrize(
"model_name",
("text-embedding-005", "text-embedding-004", "text-multilingual-embedding-002"),
@@ -251,37 +267,35 @@ def __eq__(self, other):
return self.equals(other)


@pytest.mark.skip("b/436340035 test failed")
@pytest.mark.parametrize(
(
"model_class",
"options",
),
("text_generator_model", "options"),
[
(
pytest.param(
llm.GeminiTextGenerator,
{
"temperature": 0.9,
"max_output_tokens": 8192,
"top_p": 1.0,
"ground_with_google_search": False,
},
id="gemini",
),
(
pytest.param(
llm.Claude3TextGenerator,
{
"max_output_tokens": 128,
"top_k": 40,
"top_p": 0.95,
},
id="claude",
),
],
indirect=["text_generator_model"],
)
def test_text_generator_retry_success(
session,
model_class,
text_generator_model,
options,
bq_connection,
):
# Requests.
df0 = EqCmpAllDataFrame(
@@ -298,21 +312,13 @@ def test_text_generator_retry_success(
df1 = EqCmpAllDataFrame(
{
"ml_generate_text_status": ["error", "error"],
"prompt": [
"What is BQML?",
"What is BigQuery DataFrame?",
],
"prompt": ["What is BQML?", "What is BigQuery DataFrame?"],
},
index=[1, 2],
session=session,
)
df2 = EqCmpAllDataFrame(
{
"ml_generate_text_status": ["error"],
"prompt": [
"What is BQML?",
],
},
{"ml_generate_text_status": ["error"], "prompt": ["What is BQML?"]},
index=[1],
session=session,
)
@@ -342,31 +348,21 @@ def test_text_generator_retry_success(
EqCmpAllDataFrame(
{
"ml_generate_text_status": ["error", ""],
"prompt": [
"What is BQML?",
"What is BigQuery DataFrame?",
],
"prompt": ["What is BQML?", "What is BigQuery DataFrame?"],
},
index=[1, 2],
session=session,
),
EqCmpAllDataFrame(
{
"ml_generate_text_status": [""],
"prompt": [
"What is BQML?",
],
},
{"ml_generate_text_status": [""], "prompt": ["What is BQML?"]},
index=[1],
session=session,
),
]

text_generator_model = model_class(connection_name=bq_connection, session=session)
text_generator_model._bqml_model = mock_bqml_model

with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf):
# 3rd retry isn't triggered
result = text_generator_model.predict(df0, max_retries=3)

mock_generate_text.assert_has_calls(
@@ -391,36 +387,36 @@
),
check_dtype=False,
check_index_type=False,
check_like=True,
)
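
The call sequence asserted above encodes the retry contract under test: rows whose ml_generate_text_status is "error" are re-sent, up to max_retries additional attempts, while successful rows are kept. A rough sketch of that behavior, inferred from the test data rather than taken from the library's actual implementation:

import pandas as pd

def predict_with_retries(generate, df, max_retries):
    # Send every row once, then re-send only the rows that still failed.
    successes = []
    pending = df
    for _ in range(max_retries + 1):
        out = generate(pending)
        successes.append(out[out["ml_generate_text_status"] != "error"])
        pending = out[out["ml_generate_text_status"] == "error"]
        if pending.empty:
            break
    # Rows that exhausted their retries are returned with their error status.
    return pd.concat(successes + [pending]).sort_index()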


@pytest.mark.skip("b/436340035 test failed")
@pytest.mark.parametrize(
(
"model_class",
"options",
),
("text_generator_model", "options"),
[
(
pytest.param(
llm.GeminiTextGenerator,
{
"temperature": 0.9,
"max_output_tokens": 8192,
"top_p": 1.0,
"ground_with_google_search": False,
},
id="gemini",
),
(
pytest.param(
llm.Claude3TextGenerator,
{
"max_output_tokens": 128,
"top_k": 40,
"top_p": 0.95,
},
id="claude",
),
],
indirect=["text_generator_model"],
)
def test_text_generator_retry_no_progress(session, model_class, options, bq_connection):
def test_text_generator_retry_no_progress(session, text_generator_model, options):
# Requests.
df0 = EqCmpAllDataFrame(
{
@@ -480,7 +476,6 @@ def test_text_generator_retry_no_progress(
),
]

text_generator_model = model_class(connection_name=bq_connection, session=session)
text_generator_model._bqml_model = mock_bqml_model

with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf):
@@ -508,10 +503,10 @@ def test_text_generator_retry_no_progress(session, model_class, options, bq_conn
),
check_dtype=False,
check_index_type=False,
check_like=True,
)


@pytest.mark.skip("b/436340035 test failed")
def test_text_embedding_generator_retry_success(session, bq_connection):
# Requests.
df0 = EqCmpAllDataFrame(
Expand Down Expand Up @@ -793,17 +788,28 @@ def test_gemini_preview_model_warnings(model_name):
llm.GeminiTextGenerator(model_name=model_name)


# b/436340035 temp disable the test to unblock presubmit
@pytest.mark.parametrize(
"model_class",
[
llm.TextEmbeddingGenerator,
llm.MultimodalEmbeddingGenerator,
llm.GeminiTextGenerator,
# llm.Claude3TextGenerator,
llm.Claude3TextGenerator,
],
)
def test_text_embedding_generator_no_default_model_warning(model_class):
message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message."
with pytest.warns(FutureWarning, match=message):
model_class(model_name=None)

# For Claude models, we must mock the model creation to avoid network errors.
# For all other models, we do nothing. contextlib.nullcontext() is a
# placeholder that allows the "with" statement to work for all cases.
patcher: AbstractContextManager[Any]
if model_class == llm.Claude3TextGenerator:
patcher = mock.patch.object(llm.Claude3TextGenerator, "_create_bqml_model")
else:
# We can now call nullcontext() directly
patcher = nullcontext()

with patcher:
with pytest.warns(FutureWarning, match=message):
model_class(model_name=None)
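
The patcher-or-nullcontext pattern added here generalizes to any conditional setup: contextlib.nullcontext() does nothing on enter and exit, so a single with statement covers both the patched and unpatched branches. A minimal self-contained sketch (Demo and maybe_patch are hypothetical, for illustration only):

from contextlib import nullcontext
from unittest import mock

class Demo:
    def ping(self):
        return "real"

def maybe_patch(should_patch):
    # Return a real patcher when patching is needed, otherwise a no-op
    # context manager, so callers always write the same "with" block.
    return mock.patch.object(Demo, "ping") if should_patch else nullcontext()

with maybe_patch(False):
    assert Demo().ping() == "real"  # nothing was replaced
with maybe_patch(True):
    assert Demo().ping() != "real"  # ping is a MagicMock inside this block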