Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion ax/api/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from ax.core.trial import Trial
from ax.core.trial_status import TrialStatus
from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy
from ax.exceptions.core import UnsupportedError
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.storage.sqa_store.db import init_test_engine_and_session_factory
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
Expand Down Expand Up @@ -488,6 +488,14 @@ def test_attach_data(self) -> None:
),
)

# With NaN / Inf values.
for value in [float("nan"), float("inf"), float("-inf")]:
with self.assertRaisesRegex(UserInputError, "null or inf values"):
client.attach_data(
trial_index=trial_index,
raw_data={"foo": (value, 0.0), "bar": (0.5, 0.0)},
)

def test_complete_trial(self) -> None:
client = Client()

Expand Down
4 changes: 4 additions & 0 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Any, cast, Union

import ax.core.observation as observation
import numpy as np
import pandas as pd
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
Expand Down Expand Up @@ -790,6 +791,9 @@ def attach_data(
data_init_args = data.deserialize_init_args(data.serialize_init_args(data))
if data.true_df.empty:
raise ValueError("Data to attach is empty.")
if not np.isfinite(data.true_df["mean"]).all():
# Error out if there are any NaNs or infs in the data.
raise UserInputError("Data to attach contains null or inf values.")
metrics_not_on_exp = set(data.true_df["metric_name"].values) - set(
self.metrics.keys()
)
Expand Down
39 changes: 38 additions & 1 deletion ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@
)
from ax.core.search_space import SearchSpace
from ax.core.types import ComparisonOp
from ax.exceptions.core import AxError, RunnerNotFoundError, UnsupportedError
from ax.exceptions.core import (
AxError,
RunnerNotFoundError,
UnsupportedError,
UserInputError,
)
from ax.metrics.branin import BraninMetric
from ax.runners.synthetic import SyntheticRunner
from ax.service.ax_client import AxClient
Expand Down Expand Up @@ -735,6 +740,38 @@ def test_attach_and_sort_data(self) -> None:
sorted_dfs[trial_index],
)

def test_attach_invalid_data(self) -> None:
experiment = self._setupBraninExperiment(n=1)
# Empty data.
empty_data = Data()
with self.assertRaisesRegex(ValueError, "is empty"):
experiment.attach_data(empty_data)

# Data with NaN / Inf.
for value in [None, float("nan"), float("inf")]:
data = Data(
df=pd.DataFrame(
[
{
"arm_name": "0_0",
"metric_name": "branin",
"mean": value,
"sem": 0.1,
"trial_index": 0,
},
{
"arm_name": "1_0",
"metric_name": "branin",
"mean": 5.0,
"sem": 0.1,
"trial_index": 1,
},
]
)
)
with self.assertRaisesRegex(UserInputError, "contains null or inf values"):
experiment.attach_data(data)

def test_immutable_search_space_and_opt_config(self) -> None:
mutable_exp = self._setupBraninExperiment(n=5)
self.assertFalse(mutable_exp.immutable_search_space_and_opt_config)
Expand Down
Loading