Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions bigframes/functions/_function_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def provision_bq_managed_function(
name,
packages,
is_row_processor,
bq_connection_id,
*,
capture_references=False,
):
Expand Down Expand Up @@ -273,12 +274,21 @@ def provision_bq_managed_function(
udf_code = textwrap.dedent(inspect.getsource(func))
udf_code = udf_code[udf_code.index("def") :]

with_connection_clause = (
(
f"WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`"
)
if bq_connection_id
else ""
)

create_function_ddl = (
textwrap.dedent(
f"""
CREATE OR REPLACE FUNCTION {persistent_func_id}({','.join(bq_function_args)})
RETURNS {bq_function_return_type}
LANGUAGE python
{with_connection_clause}
OPTIONS ({managed_function_options_str})
AS r'''
__UDF_PLACE_HOLDER__
Expand Down
11 changes: 8 additions & 3 deletions bigframes/functions/_function_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,9 +807,13 @@ def udf(

bq_location, _ = _utils.get_remote_function_locations(bigquery_client.location)

# A connection is required for BQ managed function.
bq_connection_id = self._resolve_bigquery_connection_id(
session, dataset_ref, bq_location, bigquery_connection
# A connection is optional for BQ managed function.
bq_connection_id = (
self._resolve_bigquery_connection_id(
session, dataset_ref, bq_location, bigquery_connection
)
if bigquery_connection
else None
)

bq_connection_manager = session.bqconnectionmanager
Expand Down Expand Up @@ -907,6 +911,7 @@ def wrapper(func):
name=name,
packages=packages,
is_row_processor=is_row_processor,
bq_connection_id=bq_connection_id,
)

# TODO(shobs): Find a better way to support udfs with param named
Expand Down
9 changes: 7 additions & 2 deletions tests/system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,13 @@ def session_tokyo(tokyo_location: str) -> Generator[bigframes.Session, None, Non


@pytest.fixture(scope="session")
def bq_connection(bigquery_client: bigquery.Client) -> str:
return f"{bigquery_client.project}.{bigquery_client.location}.bigframes-rf-conn"
def bq_connection_name() -> str:
return "bigframes-rf-conn"


@pytest.fixture(scope="session")
def bq_connection(bigquery_client: bigquery.Client, bq_connection_name: str) -> str:
return f"{bigquery_client.project}.{bigquery_client.location}.{bq_connection_name}"


@pytest.fixture(scope="session", autouse=True)
Expand Down
48 changes: 43 additions & 5 deletions tests/system/large/functions/test_managed_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,7 @@ def featurize(x: int) -> list[float]:
cleanup_function_assets(featurize, session.bqclient, ignore_failures=False)


def test_managed_function_series_apply(
session,
scalars_dfs,
):
def test_managed_function_series_apply(session, scalars_dfs):
try:

@session.udf()
Expand Down Expand Up @@ -504,7 +501,10 @@ def test_managed_function_dataframe_apply_axis_1_array_output(session):

try:

@session.udf(input_types=[int, float, str], output_type=list[str])
@session.udf(
input_types=[int, float, str],
output_type=list[str],
)
def foo(x, y, z):
return [str(x), str(y), z]

Expand Down Expand Up @@ -587,3 +587,41 @@ def foo(x, y, z):
finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(foo, session.bqclient, ignore_failures=False)


@pytest.mark.parametrize(
"connection_fixture",
[
"bq_connection_name",
"bq_connection",
],
)
def test_managed_function_with_connection(
session, scalars_dfs, request, connection_fixture
):
try:
bigquery_connection = request.getfixturevalue(connection_fixture)

@session.udf(bigquery_connection=bigquery_connection)
def foo(x: int) -> int:
return x + 10

# Function should still work normally.
assert foo(-2) == 8

scalars_df, scalars_pandas_df = scalars_dfs

bf_result_col = scalars_df["int64_too"].apply(foo)
bf_result = (
scalars_df["int64_too"].to_frame().assign(result=bf_result_col).to_pandas()
)

pd_result_col = scalars_pandas_df["int64_too"].apply(foo)
pd_result = (
scalars_pandas_df["int64_too"].to_frame().assign(result=pd_result_col)
)

pandas.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)
finally:
# Clean up the gcp assets created for the managed function.
cleanup_function_assets(foo, session.bqclient, ignore_failures=False)