11# coding=utf-8
22import threading
3- from typing import Dict
3+ from typing import Dict , Optional , List , Any
44
55from langchain_community .embeddings import XinferenceEmbeddings
6+ from langchain_core .embeddings import Embeddings
67
78from setting .models_provider .base_model_provider import MaxKBBaseModel
89
910
10- class XinferenceEmbedding (MaxKBBaseModel , XinferenceEmbeddings ):
11+ class XinferenceEmbedding (MaxKBBaseModel , Embeddings ):
12+ client : Any
13+ server_url : Optional [str ]
14+ """URL of the xinference server"""
15+ model_uid : Optional [str ]
16+ """UID of the launched model"""
17+
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
    """Build an ``XinferenceEmbedding`` from a model name and its credentials.

    Args:
        model_type: Accepted for interface compatibility; not used here.
        model_name: Passed to the constructor as ``model_uid``.
        model_credential: Mapping read for the ``'api_base'`` and
            ``'api_key'`` entries (``None`` when an entry is absent).
        **model_kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A new ``XinferenceEmbedding`` instance.
    """
    base_url = model_credential.get('api_base')
    credential_key = model_credential.get('api_key')
    return XinferenceEmbedding(
        model_uid=model_name,
        server_url=base_url,
        api_key=credential_key,
    )
1725
1826 def down_model (self ):
@@ -22,3 +30,63 @@ def start_down_model_thread(self):
2230 thread = threading .Thread (target = self .down_model )
2331 thread .daemon = True
2432 thread .start ()
33+
def __init__(
        self, server_url: Optional[str] = None, model_uid: Optional[str] = None,
        api_key: Optional[str] = None
):
    """Create an embedding client bound to one model on an Xinference server.

    Args:
        server_url: URL of the xinference server.
        model_uid: UID of the launched model.
        api_key: Optional API key, handed to the REST client unchanged.

    Raises:
        ValueError: If ``server_url`` or ``model_uid`` is ``None`` or empty.
        ImportError: If ``RESTfulClient`` can be imported from neither
            ``xinference`` nor ``xinference_client``.
    """
    # Validate the cheap arguments before touching the optional dependency,
    # so a configuration mistake is reported as such even when xinference is
    # not installed. Truthiness also rejects empty strings, which would
    # otherwise only fail later when the first request is made.
    if not server_url:
        raise ValueError("Please provide server URL")

    if not model_uid:
        raise ValueError("Please provide the model UID")

    # Prefer the full package; fall back to the standalone client package.
    try:
        from xinference.client import RESTfulClient
    except ImportError:
        try:
            from xinference_client import RESTfulClient
        except ImportError as e:
            raise ImportError(
                "Could not import RESTfulClient from xinference. Please install it"
                " with `pip install xinference` or `pip install xinference_client`."
            ) from e

    self.server_url = server_url
    self.model_uid = model_uid
    self.api_key = api_key
    self.client = RESTfulClient(server_url, api_key)
62+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Compute one embedding vector per input text via Xinference.

    The model handle is looked up once through ``self.client`` using
    ``self.model_uid``; each text is then sent in its own
    ``create_embedding`` call.

    Args:
        texts: The texts to embed.

    Returns:
        A list of float vectors, in the same order as ``texts``.
    """
    handle = self.client.get_model(self.model_uid)

    # First collect every raw vector, then normalise the element type.
    raw_vectors = []
    for document in texts:
        response = handle.create_embedding(document)
        raw_vectors.append(response["data"][0]["embedding"])

    return [[float(component) for component in vector] for vector in raw_vectors]
77+
def embed_query(self, text: str) -> List[float]:
    """Compute the embedding vector for a single query text via Xinference.

    Args:
        text: The text to embed.

    Returns:
        The embedding of ``text`` as a list of floats.
    """
    handle = self.client.get_model(self.model_uid)
    # The first entry under "data" carries the vector for the input text.
    vector = handle.create_embedding(text)["data"][0]["embedding"]
    return [float(component) for component in vector]
0 commit comments