Skip to content

Commit b58fa9c

Browse files
committed
change the default model to None, when None is provide, change default model and raise warning
1 parent 7bad5ab commit b58fa9c

File tree

1 file changed

+37
-24
lines changed

1 file changed

+37
-24
lines changed

bigframes/ml/llm.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,18 @@ class TextEmbeddingGenerator(base.RetriableRemotePredictor):
127127
def __init__(
128128
self,
129129
*,
130-
model_name: Literal[
131-
"text-embedding-005",
132-
"text-embedding-004",
133-
"text-multilingual-embedding-002",
134-
] = "text-embedding-004",
130+
model_name: Optional[
131+
Literal[
132+
"text-embedding-005",
133+
"text-embedding-004",
134+
"text-multilingual-embedding-002",
135+
]
136+
] = None,
135137
session: Optional[bigframes.Session] = None,
136138
connection_name: Optional[str] = None,
137139
):
138-
if model_name == _TEXT_EMBEDDING_004_ENDPOINT:
140+
if model_name is None:
141+
model_name = "text-embedding-004"
139142
msg = exceptions.format_message(_REMOVE_DEFAULT_MODEL_WARNING)
140143
warnings.warn(msg, category=FutureWarning, stacklevel=2)
141144
self.model_name = model_name
@@ -274,13 +277,14 @@ class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
274277
def __init__(
275278
self,
276279
*,
277-
model_name: Literal["multimodalembedding@001"] = "multimodalembedding@001",
280+
model_name: Optional[Literal["multimodalembedding@001"]] = None,
278281
session: Optional[bigframes.Session] = None,
279282
connection_name: Optional[str] = None,
280283
):
281284
if not bigframes.options.experiments.blob:
282285
raise NotImplementedError()
283-
if model_name == _MULTIMODAL_EMBEDDING_001_ENDPOINT:
286+
if model_name is None:
287+
model_name = "multimodalembedding@001"
284288
msg = exceptions.format_message(_REMOVE_DEFAULT_MODEL_WARNING)
285289
warnings.warn(msg, category=FutureWarning, stacklevel=2)
286290
self.model_name = model_name
@@ -440,17 +444,19 @@ class GeminiTextGenerator(base.RetriableRemotePredictor):
440444
def __init__(
441445
self,
442446
*,
443-
model_name: Literal[
444-
"gemini-1.5-pro-preview-0514",
445-
"gemini-1.5-flash-preview-0514",
446-
"gemini-1.5-pro-001",
447-
"gemini-1.5-pro-002",
448-
"gemini-1.5-flash-001",
449-
"gemini-1.5-flash-002",
450-
"gemini-2.0-flash-exp",
451-
"gemini-2.0-flash-001",
452-
"gemini-2.0-flash-lite-001",
453-
] = "gemini-2.0-flash-001",
447+
model_name: Optional[
448+
Literal[
449+
"gemini-1.5-pro-preview-0514",
450+
"gemini-1.5-flash-preview-0514",
451+
"gemini-1.5-pro-001",
452+
"gemini-1.5-pro-002",
453+
"gemini-1.5-flash-001",
454+
"gemini-1.5-flash-002",
455+
"gemini-2.0-flash-exp",
456+
"gemini-2.0-flash-001",
457+
"gemini-2.0-flash-lite-001",
458+
]
459+
] = None,
454460
session: Optional[bigframes.Session] = None,
455461
connection_name: Optional[str] = None,
456462
max_iterations: int = 300,
@@ -465,7 +471,8 @@ def __init__(
465471
"(https://cloud.google.com/products#product-launch-stages)."
466472
)
467473
warnings.warn(msg, category=exceptions.PreviewWarning)
468-
if model_name == _GEMINI_2_FLASH_001_ENDPOINT:
474+
if model_name is None:
475+
model_name = "gemini-2.0-flash-001"
469476
msg = exceptions.format_message(_REMOVE_DEFAULT_MODEL_WARNING)
470477
warnings.warn(msg, category=FutureWarning, stacklevel=2)
471478
self.model_name = model_name
@@ -830,13 +837,19 @@ class Claude3TextGenerator(base.RetriableRemotePredictor):
830837
def __init__(
831838
self,
832839
*,
833-
model_name: Literal[
834-
"claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"
835-
] = "claude-3-sonnet",
840+
model_name: Optional[
841+
Literal[
842+
"claude-3-sonnet",
843+
"claude-3-haiku",
844+
"claude-3-5-sonnet",
845+
"claude-3-opus",
846+
]
847+
] = None,
836848
session: Optional[bigframes.Session] = None,
837849
connection_name: Optional[str] = None,
838850
):
839-
if model_name == _CLAUDE_3_SONNET_ENDPOINT:
851+
if model_name is None:
852+
model_name = "claude-3-sonnet"
840853
msg = exceptions.format_message(_REMOVE_DEFAULT_MODEL_WARNING)
841854
warnings.warn(msg, category=FutureWarning, stacklevel=2)
842855
self.model_name = model_name

0 commit comments

Comments
 (0)