@@ -22,10 +22,9 @@ import (
2222)
2323
2424type Client struct {
25- clientsLock sync.Mutex
25+ modelsLock sync.Mutex
2626 cache * cache.Client
27- clients map [string ]* openai.Client
28- models map [string ]* openai.Client
27+ modelToProvider map [string ]string
2928 runner * runner.Runner
3029 envs []string
3130 credStore credentials.CredentialStore
@@ -43,14 +42,19 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
4342}
4443
4544func (c * Client ) Call (ctx context.Context , messageRequest types.CompletionRequest , status chan <- types.CompletionStatus ) (* types.CompletionMessage , error ) {
46- c .clientsLock .Lock ()
47- client , ok := c .models [messageRequest .Model ]
48- c .clientsLock .Unlock ()
45+ c .modelsLock .Lock ()
46+ provider , ok := c .modelToProvider [messageRequest .Model ]
47+ c .modelsLock .Unlock ()
4948
5049 if ! ok {
5150 return nil , fmt .Errorf ("failed to find remote model %s" , messageRequest .Model )
5251 }
5352
53+ client , err := c .load (ctx , provider )
54+ if err != nil {
55+ return nil , err
56+ }
57+
5458 toolName , modelName := types .SplitToolRef (messageRequest .Model )
5559 if modelName == "" {
5660 // modelName is empty, then the messageRequest.Model is not of the form 'modelName from provider'
@@ -96,19 +100,19 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
96100 return false , nil
97101 }
98102
99- client , err := c .load (ctx , providerName )
103+ _ , err := c .load (ctx , providerName )
100104 if err != nil {
101105 return false , err
102106 }
103107
104- c .clientsLock .Lock ()
105- defer c .clientsLock .Unlock ()
108+ c .modelsLock .Lock ()
109+ defer c .modelsLock .Unlock ()
106110
107- if c .models == nil {
108- c .models = map [string ]* openai. Client {}
111+ if c .modelToProvider == nil {
112+ c .modelToProvider = map [string ]string {}
109113 }
110114
111- c .models [modelString ] = client
115+ c .modelToProvider [modelString ] = providerName
112116 return true , nil
113117}
114118
@@ -141,24 +145,11 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
141145}
142146
143147func (c * Client ) load (ctx context.Context , toolName string ) (* openai.Client , error ) {
144- c .clientsLock .Lock ()
145- defer c .clientsLock .Unlock ()
146-
147- client , ok := c .clients [toolName ]
148- if ok {
149- return client , nil
150- }
151-
152- if c .clients == nil {
153- c .clients = make (map [string ]* openai.Client )
154- }
155-
156148 if isHTTPURL (toolName ) {
157149 remoteClient , err := c .clientFromURL (ctx , toolName )
158150 if err != nil {
159151 return nil , err
160152 }
161- c .clients [toolName ] = remoteClient
162153 return remoteClient , nil
163154 }
164155
@@ -174,22 +165,15 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
174165 return nil , err
175166 }
176167
177- if strings .HasSuffix (url , "/" ) {
178- url += "v1"
179- } else {
180- url += "/v1"
181- }
182-
183- client , err = openai .NewClient (ctx , c .credStore , openai.Options {
184- BaseURL : url ,
168+ client , err := openai .NewClient (ctx , c .credStore , openai.Options {
169+ BaseURL : strings .TrimSuffix (url , "/" ) + "/v1" ,
185170 Cache : c .cache ,
186171 CacheKey : prg .EntryToolID ,
187172 })
188173 if err != nil {
189174 return nil , err
190175 }
191176
192- c .clients [toolName ] = client
193177 return client , nil
194178}
195179
0 commit comments