@@ -127,15 +127,18 @@ class TextEmbeddingGenerator(base.RetriableRemotePredictor):
127
127
def __init__ (
128
128
self ,
129
129
* ,
130
- model_name : Literal [
131
- "text-embedding-005" ,
132
- "text-embedding-004" ,
133
- "text-multilingual-embedding-002" ,
134
- ] = "text-embedding-004" ,
130
+ model_name : Optional [
131
+ Literal [
132
+ "text-embedding-005" ,
133
+ "text-embedding-004" ,
134
+ "text-multilingual-embedding-002" ,
135
+ ]
136
+ ] = None ,
135
137
session : Optional [bigframes .Session ] = None ,
136
138
connection_name : Optional [str ] = None ,
137
139
):
138
- if model_name == _TEXT_EMBEDDING_004_ENDPOINT :
140
+ if model_name is None :
141
+ model_name = "text-embedding-004"
139
142
msg = exceptions .format_message (_REMOVE_DEFAULT_MODEL_WARNING )
140
143
warnings .warn (msg , category = FutureWarning , stacklevel = 2 )
141
144
self .model_name = model_name
@@ -274,13 +277,14 @@ class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
274
277
def __init__ (
275
278
self ,
276
279
* ,
277
- model_name : Literal ["multimodalembedding@001" ] = "multimodalembedding@001" ,
280
+ model_name : Optional [ Literal ["multimodalembedding@001" ]] = None ,
278
281
session : Optional [bigframes .Session ] = None ,
279
282
connection_name : Optional [str ] = None ,
280
283
):
281
284
if not bigframes .options .experiments .blob :
282
285
raise NotImplementedError ()
283
- if model_name == _MULTIMODAL_EMBEDDING_001_ENDPOINT :
286
+ if model_name is None :
287
+ model_name = "multimodalembedding@001"
284
288
msg = exceptions .format_message (_REMOVE_DEFAULT_MODEL_WARNING )
285
289
warnings .warn (msg , category = FutureWarning , stacklevel = 2 )
286
290
self .model_name = model_name
@@ -440,17 +444,19 @@ class GeminiTextGenerator(base.RetriableRemotePredictor):
440
444
def __init__ (
441
445
self ,
442
446
* ,
443
- model_name : Literal [
444
- "gemini-1.5-pro-preview-0514" ,
445
- "gemini-1.5-flash-preview-0514" ,
446
- "gemini-1.5-pro-001" ,
447
- "gemini-1.5-pro-002" ,
448
- "gemini-1.5-flash-001" ,
449
- "gemini-1.5-flash-002" ,
450
- "gemini-2.0-flash-exp" ,
451
- "gemini-2.0-flash-001" ,
452
- "gemini-2.0-flash-lite-001" ,
453
- ] = "gemini-2.0-flash-001" ,
447
+ model_name : Optional [
448
+ Literal [
449
+ "gemini-1.5-pro-preview-0514" ,
450
+ "gemini-1.5-flash-preview-0514" ,
451
+ "gemini-1.5-pro-001" ,
452
+ "gemini-1.5-pro-002" ,
453
+ "gemini-1.5-flash-001" ,
454
+ "gemini-1.5-flash-002" ,
455
+ "gemini-2.0-flash-exp" ,
456
+ "gemini-2.0-flash-001" ,
457
+ "gemini-2.0-flash-lite-001" ,
458
+ ]
459
+ ] = None ,
454
460
session : Optional [bigframes .Session ] = None ,
455
461
connection_name : Optional [str ] = None ,
456
462
max_iterations : int = 300 ,
@@ -465,7 +471,8 @@ def __init__(
465
471
"(https://cloud.google.com/products#product-launch-stages)."
466
472
)
467
473
warnings .warn (msg , category = exceptions .PreviewWarning )
468
- if model_name == _GEMINI_2_FLASH_001_ENDPOINT :
474
+ if model_name is None :
475
+ model_name = "gemini-2.0-flash-001"
469
476
msg = exceptions .format_message (_REMOVE_DEFAULT_MODEL_WARNING )
470
477
warnings .warn (msg , category = FutureWarning , stacklevel = 2 )
471
478
self .model_name = model_name
@@ -830,13 +837,19 @@ class Claude3TextGenerator(base.RetriableRemotePredictor):
830
837
def __init__ (
831
838
self ,
832
839
* ,
833
- model_name : Literal [
834
- "claude-3-sonnet" , "claude-3-haiku" , "claude-3-5-sonnet" , "claude-3-opus"
835
- ] = "claude-3-sonnet" ,
840
+ model_name : Optional [
841
+ Literal [
842
+ "claude-3-sonnet" ,
843
+ "claude-3-haiku" ,
844
+ "claude-3-5-sonnet" ,
845
+ "claude-3-opus" ,
846
+ ]
847
+ ] = None ,
836
848
session : Optional [bigframes .Session ] = None ,
837
849
connection_name : Optional [str ] = None ,
838
850
):
839
- if model_name == _CLAUDE_3_SONNET_ENDPOINT :
851
+ if model_name is None :
852
+ model_name = "claude-3-sonnet"
840
853
msg = exceptions .format_message (_REMOVE_DEFAULT_MODEL_WARNING )
841
854
warnings .warn (msg , category = FutureWarning , stacklevel = 2 )
842
855
self .model_name = model_name
0 commit comments