Skip to content

feat: use ADK built-in BigQuery tools in the data science agent #298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions python/agents/data-science/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ Evaluation tests assess the overall performance and capabilities of the agent in

**Run Evaluation Tests:**

```bash
poetry run pytest eval
```
```bash
poetry run pytest eval
```


- This command executes all test files within the `eval/` directory.
Expand All @@ -266,9 +266,9 @@ Tests assess the overall executability of the agents.

**Run Tests:**

```bash
poetry run pytest tests
```
```bash
poetry run pytest tests
```

- This command executes all test files within the `tests/` directory.
- `poetry run` ensures that pytest runs within the project's virtual environment.
Expand Down
2 changes: 1 addition & 1 deletion python/agents/data-science/data_science/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def setup_before_agent_call(callback_context: CallbackContext):
# setting up schema in instruction
if callback_context.state["all_db_settings"]["use_database"] == "BigQuery":
callback_context.state["database_settings"] = get_bq_database_settings()
schema = callback_context.state["database_settings"]["bq_ddl_schema"]
schema = callback_context.state["database_settings"]["bq_schema_and_samples"]

callback_context._invocation_context.agent.instruction = (
return_instructions_root()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@

import os

from typing import Any, Dict, Optional

from google.adk.agents import Agent
from google.adk.agents.callback_context import CallbackContext
from google.adk.tools import BaseTool, ToolContext
from google.adk.tools.bigquery import BigQueryToolset
from google.adk.tools.bigquery.config import BigQueryToolConfig, WriteMode
from google.genai import types

from . import tools
Expand All @@ -26,6 +31,10 @@

NL2SQL_METHOD = os.getenv("NL2SQL_METHOD", "BASELINE")

# BigQuery built-in tools in ADK
# https://google.github.io/adk-docs/tools/built-in-tools/#bigquery
ADK_BUILTIN_BQ_EXECUTE_SQL_TOOL = "execute_sql"


def setup_before_agent_call(callback_context: CallbackContext) -> None:
"""Setup the agent."""
Expand All @@ -35,6 +44,29 @@ def setup_before_agent_call(callback_context: CallbackContext) -> None:
tools.get_database_settings()


def handle_after_tool_call(
    tool: BaseTool, args: Dict[str, Any], tool_context: ToolContext, tool_response: Dict
) -> Optional[Dict]:
    """After-tool callback that stashes BigQuery query results in session state.

    When the ADK built-in ``execute_sql`` tool succeeds, its returned rows are
    copied into ``tool_context.state["query_result"]`` so downstream agents
    (e.g. the data science agent) can use the SQL results as context.

    Args:
        tool: The tool that was just invoked.
        args: The arguments the tool was called with (unused here).
        tool_context: Invocation context; its ``state`` dict is mutated.
        tool_response: The raw response dict returned by the tool.

    Returns:
        None, so the original ``tool_response`` is passed through unchanged.
    """
    if tool.name == ADK_BUILTIN_BQ_EXECUTE_SQL_TOOL:
        # Use .get() rather than direct indexing: an unexpected response shape
        # (missing "status"/"rows") must not crash the callback chain.
        if tool_response.get("status") == "SUCCESS":
            tool_context.state["query_result"] = tool_response.get("rows")

    return None


# Restrict the ADK BigQuery toolset to the single execute_sql tool; the
# NL2SQL generation itself is handled by the project's own tools.
bigquery_tool_filter = [ADK_BUILTIN_BQ_EXECUTE_SQL_TOOL]
bigquery_tool_config = BigQueryToolConfig(
    # BLOCKED write mode keeps query execution read-only.
    write_mode=WriteMode.BLOCKED,
    # Cap rows returned per query — presumably to bound the LLM context
    # size when results are injected as state; confirm the 80-row choice.
    max_query_result_rows=80
)
bigquery_toolset = BigQueryToolset(
    tool_filter=bigquery_tool_filter,
    bigquery_tool_config=bigquery_tool_config
)

database_agent = Agent(
model=os.getenv("BIGQUERY_AGENT_MODEL"),
name="database_agent",
Expand All @@ -45,8 +77,9 @@ def setup_before_agent_call(callback_context: CallbackContext) -> None:
if NL2SQL_METHOD == "CHASE"
else tools.initial_bq_nl2sql
),
tools.run_bigquery_validation,
bigquery_toolset,
],
before_agent_callback=setup_before_agent_call,
after_tool_callback=handle_after_tool_call,
generate_content_config=types.GenerateContentConfig(temperature=0.01),
)
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def initial_bq_nl2sql(
str: An SQL statement to answer this question.
"""
print("****** Running agent with ChaseSQL algorithm.")
ddl_schema = tool_context.state["database_settings"]["bq_ddl_schema"]
bq_schema_and_samples = tool_context.state["database_settings"]["bq_schema_and_samples"]
project = tool_context.state["database_settings"]["bq_data_project_id"]
db = tool_context.state["database_settings"]["bq_dataset_id"]
transpile_to_bigquery = tool_context.state["database_settings"][
Expand All @@ -114,13 +114,13 @@ def initial_bq_nl2sql(

if generate_sql_type == GenerateSQLType.DC.value:
prompt = DC_PROMPT_TEMPLATE.format(
SCHEMA=ddl_schema,
SCHEMA=bq_schema_and_samples,
QUESTION=question,
BQ_DATA_PROJECT_ID=BQ_DATA_PROJECT_ID
)
elif generate_sql_type == GenerateSQLType.QP.value:
prompt = QP_PROMPT_TEMPLATE.format(
SCHEMA=ddl_schema,
SCHEMA=bq_schema_and_samples,
QUESTION=question,
BQ_DATA_PROJECT_ID=BQ_DATA_PROJECT_ID
)
Expand All @@ -145,7 +145,7 @@ def initial_bq_nl2sql(
# pylint: disable=g-bad-todo
# pylint: enable=g-bad-todo
responses: str = translator.translate(
responses, ddl_schema=ddl_schema, db=db, catalog=project
responses, ddl_schema=bq_schema_and_samples, db=db, catalog=project
)

return responses
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""

import os
from data_science.utils.utils import get_env_var


def return_instructions_bigquery() -> str:
Expand All @@ -30,7 +31,30 @@ def return_instructions_bigquery() -> str:
db_tool_name = None
raise ValueError(f"Unknown NL2SQL method: {NL2SQL_METHOD}")

instruction_prompt_bqml_v1 = f"""
instruction_prompt_bq_v2 = f"""
You are an AI assistant serving as a SQL expert for BigQuery.
Your job is to help users generate SQL answers from natural language questions (inside Nl2sqlInput).
    You should produce the result as NL2SQLOutput.

Use the provided tools to help generate the most accurate SQL:
1. First, use {db_tool_name} tool to generate initial SQL from the question.
2. Then you should use the execute_sql tool to validate and execute the SQL. If there are any errors with the SQL, you should go back to step 1 and recreate the SQL by addressing the error.
    3. Generate the final result in JSON format with four keys: "explain", "sql", "sql_results", "nl_results".
"explain": "write out step-by-step reasoning to explain how you are generating the query based on the schema, example, and question.",
"sql": "Output your generated SQL!",
"sql_results": "raw sql execution query_result from execute_sql if it's available, otherwise None",
"nl_results": "Natural language about results, otherwise it's None if generated SQL is invalid"

You should pass one tool call to another tool call as needed!

NOTE: you should ALWAYS USE THE TOOL {db_tool_name} to generate SQL, not make up SQL WITHOUT CALLING TOOLS.
Keep in mind that you are an orchestration agent, not a SQL expert, so use the tools to help you generate SQL, but do not make up SQL.

NOTE: you must ALWAYS PASS the project_id {get_env_var("BQ_COMPUTE_PROJECT_ID")} to the execute_sql tool. DO NOT pass any other project id.

"""

instruction_prompt_bq_v1 = f"""
You are an AI assistant serving as a SQL expert for BigQuery.
Your job is to help users generate SQL answers from natural language questions (inside Nl2sqlInput).
    You should produce the result as NL2SQLOutput.
Expand All @@ -51,4 +75,4 @@ def return_instructions_bigquery() -> str:

"""

return instruction_prompt_bqml_v1
return instruction_prompt_bq_v2
Loading