From 0b2efd8e2719a9e52d1e4eb13d4ab81433cd3a73 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Fri, 13 Jan 2023 21:40:38 -0800 Subject: [PATCH 01/30] wip --- answers.go | 2 +- api.go | 2 +- api_test.go | 109 +++++++++++++++++++++++++------------------------- common.go | 3 +- completion.go | 2 +- edits.go | 2 +- embeddings.go | 2 +- engines.go | 2 +- error.go | 2 +- files.go | 2 +- go.mod | 4 +- go.sum | 2 + image.go | 2 +- moderation.go | 2 +- 14 files changed, 69 insertions(+), 69 deletions(-) create mode 100644 go.sum diff --git a/answers.go b/answers.go index 3a20f2a..3905d68 100644 --- a/answers.go +++ b/answers.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" diff --git a/api.go b/api.go index c339afe..c54b3b7 100644 --- a/api.go +++ b/api.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "encoding/json" diff --git a/api_test.go b/api_test.go index 1e1c5d0..c3e9cf6 100644 --- a/api_test.go +++ b/api_test.go @@ -1,10 +1,11 @@ -package gogpt_test +package openai_test import ( "bytes" "context" "encoding/json" "fmt" + "github.com/fabiustech/openai" "io/ioutil" "log" "net/http" @@ -14,8 +15,6 @@ import ( "strings" "testing" "time" - - . 
"github.com/sashabaranov/go-gpt3" ) const ( @@ -29,7 +28,7 @@ func TestAPI(t *testing.T) { } var err error - c := NewClient(apiToken) + c := openai.NewClient(apiToken) ctx := context.Background() _, err = c.ListEngines(ctx) if err != nil { @@ -53,12 +52,12 @@ func TestAPI(t *testing.T) { } } // else skip - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: AdaSearchQuery, + Model: openai.AdaSearchQuery, } _, err = c.CreateEmbeddings(ctx, embeddingReq) if err != nil { @@ -74,11 +73,11 @@ func TestCompletions(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(testAPIToken) + client := openai.NewClient(testAPIToken) ctx := context.Background() client.BaseURL = ts.URL + "/v1" - req := CompletionRequest{ + req := openai.CompletionRequest{ MaxTokens: 5, Model: "ada", } @@ -97,13 +96,13 @@ func TestEdits(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(testAPIToken) + client := openai.NewClient(testAPIToken) ctx := context.Background() client.BaseURL = ts.URL + "/v1" // create an edit request model := "ada" - editReq := EditsRequest{ + editReq := openai.EditsRequest{ Model: &model, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. 
Ut enim" + @@ -122,26 +121,26 @@ func TestEdits(t *testing.T) { } func TestEmbedding(t *testing.T) { - embeddedModels := []EmbeddingModel{ - AdaSimilarity, - BabbageSimilarity, - CurieSimilarity, - DavinciSimilarity, - AdaSearchDocument, - AdaSearchQuery, - BabbageSearchDocument, - BabbageSearchQuery, - CurieSearchDocument, - CurieSearchQuery, - DavinciSearchDocument, - DavinciSearchQuery, - AdaCodeSearchCode, - AdaCodeSearchText, - BabbageCodeSearchCode, - BabbageCodeSearchText, + embeddedModels := []openai.EmbeddingModel{ + openai.AdaSimilarity, + openai.BabbageSimilarity, + openai.CurieSimilarity, + openai.DavinciSimilarity, + openai.AdaSearchDocument, + openai.AdaSearchQuery, + openai.BabbageSearchDocument, + openai.BabbageSearchQuery, + openai.CurieSearchDocument, + openai.CurieSearchQuery, + openai.DavinciSearchDocument, + openai.DavinciSearchQuery, + openai.AdaCodeSearchCode, + openai.AdaCodeSearchText, + openai.BabbageCodeSearchCode, + openai.BabbageCodeSearchText, } for _, model := range embeddedModels { - embeddingReq := EmbeddingRequest{ + embeddingReq := openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -161,16 +160,16 @@ func TestEmbedding(t *testing.T) { } // getEditBody Returns the body of the request to create an edit. 
-func getEditBody(r *http.Request) (EditsRequest, error) { - edit := EditsRequest{} +func getEditBody(r *http.Request) (openai.EditsRequest, error) { + edit := openai.EditsRequest{} // read the request body reqBody, err := ioutil.ReadAll(r.Body) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } err = json.Unmarshal(reqBody, &edit) if err != nil { - return EditsRequest{}, err + return openai.EditsRequest{}, err } return edit, nil } @@ -184,14 +183,14 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq EditsRequest + var editReq openai.EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := EditsResponse{ + res := openai.EditsResponse{ Object: "test-object", Created: uint64(time.Now().Unix()), } @@ -201,12 +200,12 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { completionTokens := int(float32(len(editString))/4) * editReq.N for i := 0; i < editReq.N; i++ { // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, EditsChoice{ + res.Choices = append(res.Choices, openai.EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -224,12 +223,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq CompletionRequest + var completionReq openai.CompletionRequest if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := CompletionResponse{ + res := openai.CompletionResponse{ ID: 
strconv.Itoa(int(time.Now().Unix())), Object: "test-object", Created: uint64(time.Now().Unix()), @@ -245,14 +244,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if completionReq.Echo { completionStr = completionReq.Prompt + completionStr } - res.Choices = append(res.Choices, CompletionChoice{ + res.Choices = append(res.Choices, openai.CompletionChoice{ Text: completionStr, Index: i, }) } inputTokens := numTokens(completionReq.Prompt) * completionReq.N completionTokens := completionReq.MaxTokens * completionReq.N - res.Usage = Usage{ + res.Usage = openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -270,20 +269,20 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq ImageRequest + var imageReq openai.ImageRequest if imageReq, err = getImageBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := ImageResponse{ + res := openai.ImageResponse{ Created: uint64(time.Now().Unix()), } for i := 0; i < imageReq.N; i++ { - imageData := ImageResponseDataInner{} + imageData := openai.ImageResponseDataInner{} switch imageReq.ResponseFormat { - case CreateImageResponseFormatURL, "": + case openai.CreateImageResponseFormatURL, "": imageData.URL = "https://example.com/image.png" - case CreateImageResponseFormatB64JSON: + case openai.CreateImageResponseFormatB64JSON: // This decodes to "{}" in base64. imageData.B64JSON = "e30K" default: @@ -297,31 +296,31 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { } // getCompletionBody Returns the body of the request to create a completion. 
-func getCompletionBody(r *http.Request) (CompletionRequest, error) { - completion := CompletionRequest{} +func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { + completion := openai.CompletionRequest{} // read the request body reqBody, err := ioutil.ReadAll(r.Body) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return CompletionRequest{}, err + return openai.CompletionRequest{}, err } return completion, nil } // getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (ImageRequest, error) { - image := ImageRequest{} +func getImageBody(r *http.Request) (openai.ImageRequest, error) { + image := openai.ImageRequest{} // read the request body reqBody, err := ioutil.ReadAll(r.Body) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } err = json.Unmarshal(reqBody, &image) if err != nil { - return ImageRequest{}, err + return openai.ImageRequest{}, err } return image, nil } @@ -342,11 +341,11 @@ func TestImages(t *testing.T) { ts.Start() defer ts.Close() - client := NewClient(testAPIToken) + client := openai.NewClient(testAPIToken) ctx := context.Background() client.BaseURL = ts.URL + "/v1" - req := ImageRequest{} + req := openai.ImageRequest{} req.Prompt = "Lorem ipsum" _, err = client.CreateImage(ctx, req) if err != nil { diff --git a/common.go b/common.go index 9fb0178..3121ede 100644 --- a/common.go +++ b/common.go @@ -1,5 +1,4 @@ -// common.go defines common types used throughout the OpenAI API. -package gogpt +package openai // Usage Represents the total token usage per request to OpenAI. 
type Usage struct { diff --git a/completion.go b/completion.go index 97601c3..cbc5645 100644 --- a/completion.go +++ b/completion.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" diff --git a/edits.go b/edits.go index 8101429..a63875a 100644 --- a/edits.go +++ b/edits.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" diff --git a/embeddings.go b/embeddings.go index 52c5223..3518cd9 100644 --- a/embeddings.go +++ b/embeddings.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" diff --git a/engines.go b/engines.go index 3f82e98..350241b 100644 --- a/engines.go +++ b/engines.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "context" diff --git a/error.go b/error.go index 4d0a324..3f65b66 100644 --- a/error.go +++ b/error.go @@ -1,4 +1,4 @@ -package gogpt +package openai type ErrorResponse struct { Error *struct { diff --git a/files.go b/files.go index 672f060..0dfb41b 100644 --- a/files.go +++ b/files.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" diff --git a/go.mod b/go.mod index 4b6bb42..aaba303 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ -module github.com/sashabaranov/go-gpt3 +module github.com/fabiustech/openai -go 1.17 +go 1.19 \ No newline at end of file diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..18421b0 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/sashabaranov/go-gpt3 v0.0.0-20230112080207-81b5788cd68e h1:F+NLURcyhx2LWefyiDwLuwRzSsraL2hDyq0pPS853bo= +github.com/sashabaranov/go-gpt3 v0.0.0-20230112080207-81b5788cd68e/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ= diff --git a/image.go b/image.go index 335e82f..03292d6 100644 --- a/image.go +++ b/image.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" diff --git a/moderation.go b/moderation.go index 1058693..54f4ff3 100644 --- a/moderation.go +++ b/moderation.go @@ -1,4 +1,4 @@ -package gogpt +package openai import ( "bytes" From 023b6540e0888a79827e914cc41e838ed90b5b17 
Mon Sep 17 00:00:00 2001 From: Andy Day Date: Fri, 13 Jan 2023 21:40:53 -0800 Subject: [PATCH 02/30] go mod tidy --- go.mod | 2 +- go.sum | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/go.mod b/go.mod index aaba303..54ca0af 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/fabiustech/openai -go 1.19 \ No newline at end of file +go 1.19 diff --git a/go.sum b/go.sum index 18421b0..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +0,0 @@ -github.com/sashabaranov/go-gpt3 v0.0.0-20230112080207-81b5788cd68e h1:F+NLURcyhx2LWefyiDwLuwRzSsraL2hDyq0pPS853bo= -github.com/sashabaranov/go-gpt3 v0.0.0-20230112080207-81b5788cd68e/go.mod h1:BIZdbwdzxZbCrcKGMGH6u2eyGe1xFuX9Anmh3tCP8lQ= From 70f3cc9409dbc153367d04d1ddc4694bc1fae2a5 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Fri, 13 Jan 2023 21:42:06 -0800 Subject: [PATCH 03/30] update readme --- README.md | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index b02976d..6d9b09e 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Installation: ``` -go get github.com/sashabaranov/go-gpt3 +go get github.com/fabiustech/openai ``` @@ -19,15 +19,16 @@ package main import ( "context" "fmt" - gogpt "github.com/sashabaranov/go-gpt3" + + "github.com/fabiustech/openai" ) func main() { - c := gogpt.NewClient("your token") + c := openai.NewClient("your token") ctx := context.Background() - req := gogpt.CompletionRequest{ - Model: gogpt.GPT3Ada, + req := openai.CompletionRequest{ + Model: openai.GPT3Ada, MaxTokens: 5, Prompt: "Lorem ipsum", } From 478df094d5cf674ef8dcd50da14138cd9a3b7ee7 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sat, 14 Jan 2023 08:01:53 -0800 Subject: [PATCH 04/30] progress --- answers.go | 24 ++++--- completion.go | 149 ++++++++++++++++++++++++++---------------- embeddings.go | 99 ++-------------------------- models/completions.go | 42 ++++++++++++ models/embeddings.go | 89 +++++++++++++++++++++++++ models/enum.go | 1 + 6 files 
changed, 244 insertions(+), 160 deletions(-) create mode 100644 models/completions.go create mode 100644 models/embeddings.go create mode 100644 models/enum.go diff --git a/answers.go b/answers.go index 3905d68..bd2b463 100644 --- a/answers.go +++ b/answers.go @@ -32,20 +32,24 @@ type AnswerResponse struct { } `json:"selected_documents"` } -// Search — perform a semantic search api call over a list of documents. -func (c *Client) Answers(ctx context.Context, request AnswerRequest) (response AnswerResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) +// Answers ... +func (c *Client) Answers(ctx context.Context, ar AnswerRequest) (*AnswerResponse, error) { + var b, err = json.Marshal(ar) if err != nil { - return + return nil, err } - req, err := http.NewRequest("POST", c.fullURL("/answers"), bytes.NewBuffer(reqBytes)) + var req *http.Request + req, err = http.NewRequest("POST", c.fullURL("/answers"), bytes.NewBuffer(b)) if err != nil { - return + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return + + var resp *AnswerResponse + if err = c.sendRequest(req, resp); err != nil { + return nil, err + } + + return resp, err } diff --git a/completion.go b/completion.go index cbc5645..37c1845 100644 --- a/completion.go +++ b/completion.go @@ -7,52 +7,86 @@ import ( "net/http" ) -// GPT3 Defines the models provided by OpenAI to use when generating -// completions from OpenAI. -// GPT3 Models are designed for text-based tasks. For code-specific -// tasks, please refer to the Codex series of models. 
-const ( - GPT3TextDavinci003 = "text-davinci-003" - GPT3TextDavinci002 = "text-davinci-002" - GPT3TextCurie001 = "text-curie-001" - GPT3TextBabbage001 = "text-babbage-001" - GPT3TextAda001 = "text-ada-001" - GPT3TextDavinci001 = "text-davinci-001" - GPT3DavinciInstructBeta = "davinci-instruct-beta" - GPT3Davinci = "davinci" - GPT3CurieInstructBeta = "curie-instruct-beta" - GPT3Curie = "curie" - GPT3Ada = "ada" - GPT3Babbage = "babbage" -) - -// Codex Defines the models provided by OpenAI. -// These models are designed for code-specific tasks, and use -// a different tokenizer which optimizes for whitespace. -const ( - CodexCodeDavinci002 = "code-davinci-002" - CodexCodeCushman001 = "code-cushman-001" - CodexCodeDavinci001 = "code-davinci-001" -) - -// CompletionRequest represents a request structure for completion API. +// CompletionRequest represents a request structure for Completion API. type CompletionRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt,omitempty"` - Suffix string `json:"suffix,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Stream bool `json:"stream,omitempty"` - LogProbs int `json:"logprobs,omitempty"` - Echo bool `json:"echo,omitempty"` - Stop []string `json:"stop,omitempty"` - PresencePenalty float32 `json:"presence_penalty,omitempty"` - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - BestOf int `json:"best_of,omitempty"` - LogitBias map[string]int `json:"logit_bias,omitempty"` - User string `json:"user,omitempty"` + // Model specifies the ID of the model to use. + // See more here: https://beta.openai.com/docs/models/overview + Model string `json:"model"` + // Prompt specifies the prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, + // or array of token arrays. 
Note that <|endoftext|> is the document separator that the model sees during + // training, so if a prompt is not specified the model will generate as if from the beginning of a new document. + // Defaults to <|endoftext|>. + Prompt *string `json:"prompt,omitempty"` + // Suffix specifies the suffix that comes after a completion of inserted text. + // Defaults to null. + Suffix *string `json:"suffix,omitempty"` + // MaxTokens specifies the maximum number of tokens to generate in the completion. The token count of your prompt plus + // max_tokens cannot exceed the model's context length. Most models have a context length of 2048 tokens (except + // for the newest models, which support 4096). + // Defaults to 16. + MaxTokens *int `json:"max_tokens,omitempty"` + // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative + // applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally recommends altering + // this or top_p but not both. + // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 + // Defaults to 1. + Temperature *float32 `json:"temperature,omitempty"` + // TopP specifies an alternative to sampling with temperature, called nucleus sampling, where the model considers + // the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + // probability mass are considered. OpenAI generally recommends altering this or temperature but not both. + // Defaults to 1. + TopP *float32 `json:"top_p,omitempty"` + // N specifies how many completions to generate for each prompt. + // Note: Because this parameter generates many completions, it can quickly consume your token quota. Use carefully + // and ensure that you have reasonable settings for max_tokens and stop. + // Defaults to 1. 
+ N *int `json:"n,omitempty"` + // Steam specifies Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent + // events as they become available, with the stream terminated by a data: [DONE] message. + // Defaults to false. + Stream bool `json:"stream,omitempty"` + // LogProbs specifies to include the log probabilities on the logprobs most likely tokens, as well the chosen + // tokens. For example, if logprobs is 5, the API will return a list of the 5 most likely tokens. The API will + // always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response. + // The maximum value for logprobs is 5. + // Defaults to null. + LogProbs *int `json:"logprobs,omitempty"` + // Echo specifies to echo back the prompt in addition to the completion. + // Defaults to false. + Echo bool `json:"echo,omitempty"` + // Stop specifies up to 4 sequences where the API will stop generating further tokens. The returned text will not + // contain the stop sequence. + Stop []string `json:"stop,omitempty"` + // PresencePenalty can be a number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + // appear in the text so far, increasing the model's likelihood to talk about new topics. + // Defaults to 0. + PresencePenalty float32 `json:"presence_penalty,omitempty"` + // FrequencyPenalty can be a number between -2.0 and 2.0. Positive values penalize new tokens based on their + // existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + // Defaults to 0. + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + // Generates best_of completions server-side and returns the "best" (the one with the highest log probability per + // token). Results cannot be streamed. When used with n, best_of controls the number of candidate completions and n + // specifies how many to return – best_of must be greater than n. 
Note: Because this parameter generates many + // completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings + // for max_tokens and stop. + // Defaults to 1. + BestOf *int `json:"best_of,omitempty"` + // LogitBias modifies the likelihood of specified tokens appearing in the completion. Accepts a json object that + // maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. + // Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will + // vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like + // -100 or 100 should result in a ban or exclusive selection of the relevant token. + // As an example, you can pass {"50256": -100} to prevent the <|endoftext|> token from being generated. + // + // You can use this tokenizer tool to convert text to token IDs: + // https://beta.openai.com/tokenizer + // + // Defaults to null. + LogitBias map[string]int `json:"logit_bias,omitempty"` + // User is a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + // See more here: https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids + User string `json:"user,omitempty"` } // CompletionChoice represents one of possible completions. @@ -86,23 +120,24 @@ type CompletionResponse struct { // // If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object, // and the server will use the model's parameters to generate the completion. 
-func (c *Client) CreateCompletion( - ctx context.Context, - request CompletionRequest, -) (response CompletionResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) +func (c *Client) CreateCompletion(ctx context.Context, cr CompletionRequest) (*CompletionResponse, error) { + var b, err = json.Marshal(cr) if err != nil { - return + return nil, err } urlSuffix := "/completions" - req, err := http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) + var req *http.Request + req, err = http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(b)) if err != nil { - return + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return + + var resp *CompletionResponse + if err = c.sendRequest(req, resp); err != nil { + return nil, err + } + + return resp, nil } diff --git a/embeddings.go b/embeddings.go index 3518cd9..f811337 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,97 +4,10 @@ import ( "bytes" "context" "encoding/json" + "github.com/fabiustech/openai/models" "net/http" ) -// EmbeddingModel enumerates the models which can be used -// to generate Embedding vectors. -type EmbeddingModel int - -// String implements the fmt.Stringer interface. -func (e EmbeddingModel) String() string { - return enumToString[e] -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (e EmbeddingModel) MarshalText() ([]byte, error) { - return []byte(e.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// On unrecognized value, it sets |e| to Unknown. 
-func (e *EmbeddingModel) UnmarshalText(b []byte) error { - if val, ok := stringToEnum[(string(b))]; ok { - *e = val - return nil - } - - *e = Unknown - - return nil -} - -const ( - Unknown EmbeddingModel = iota - AdaSimilarity - BabbageSimilarity - CurieSimilarity - DavinciSimilarity - AdaSearchDocument - AdaSearchQuery - BabbageSearchDocument - BabbageSearchQuery - CurieSearchDocument - CurieSearchQuery - DavinciSearchDocument - DavinciSearchQuery - AdaCodeSearchCode - AdaCodeSearchText - BabbageCodeSearchCode - BabbageCodeSearchText - AdaEmbeddingV2 -) - -var enumToString = map[EmbeddingModel]string{ - AdaSimilarity: "text-similarity-ada-001", - BabbageSimilarity: "text-similarity-babbage-001", - CurieSimilarity: "text-similarity-curie-001", - DavinciSimilarity: "text-similarity-davinci-001", - AdaSearchDocument: "text-search-ada-doc-001", - AdaSearchQuery: "text-search-ada-query-001", - BabbageSearchDocument: "text-search-babbage-doc-001", - BabbageSearchQuery: "text-search-babbage-query-001", - CurieSearchDocument: "text-search-curie-doc-001", - CurieSearchQuery: "text-search-curie-query-001", - DavinciSearchDocument: "text-search-davinci-doc-001", - DavinciSearchQuery: "text-search-davinci-query-001", - AdaCodeSearchCode: "code-search-ada-code-001", - AdaCodeSearchText: "code-search-ada-text-001", - BabbageCodeSearchCode: "code-search-babbage-code-001", - BabbageCodeSearchText: "code-search-babbage-text-001", - AdaEmbeddingV2: "text-embedding-ada-002", -} - -var stringToEnum = map[string]EmbeddingModel{ - "text-similarity-ada-001": AdaSimilarity, - "text-similarity-babbage-001": BabbageSimilarity, - "text-similarity-curie-001": CurieSimilarity, - "text-similarity-davinci-001": DavinciSimilarity, - "text-search-ada-doc-001": AdaSearchDocument, - "text-search-ada-query-001": AdaSearchQuery, - "text-search-babbage-doc-001": BabbageSearchDocument, - "text-search-babbage-query-001": BabbageSearchQuery, - "text-search-curie-doc-001": CurieSearchDocument, - 
"text-search-curie-query-001": CurieSearchQuery, - "text-search-davinci-doc-001": DavinciSearchDocument, - "text-search-davinci-query-001": DavinciSearchQuery, - "code-search-ada-code-001": AdaCodeSearchCode, - "code-search-ada-text-001": AdaCodeSearchText, - "code-search-babbage-code-001": BabbageCodeSearchCode, - "code-search-babbage-text-001": BabbageCodeSearchText, - "text-embedding-ada-002": AdaEmbeddingV2, -} - // Embedding is a special format of data representation that can be easily utilized by machine // learning models and algorithms. The embedding is an information dense representation of the // semantic meaning of a piece of text. Each embedding is a vector of floating point numbers, @@ -109,10 +22,10 @@ type Embedding struct { // EmbeddingResponse is the response from a Create embeddings request. type EmbeddingResponse struct { - Object string `json:"object"` - Data []Embedding `json:"data"` - Model EmbeddingModel `json:"model"` - Usage Usage `json:"usage"` + Object string `json:"object"` + Data []Embedding `json:"data"` + Model models.Embedding `json:"model"` + Usage Usage `json:"usage"` } // EmbeddingRequest is the input to a Create embeddings request. @@ -126,7 +39,7 @@ type EmbeddingRequest struct { Input []string `json:"input"` // ID of the model to use. You can use the List models API to see all of your available models, // or see our Model overview for descriptions of them. - Model EmbeddingModel `json:"model"` + Model models.Embedding `json:"model"` // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. 
User string `json:"user"` } diff --git a/models/completions.go b/models/completions.go new file mode 100644 index 0000000..a5d569f --- /dev/null +++ b/models/completions.go @@ -0,0 +1,42 @@ +package models + +type Completion int + +const ( + TextDavinci003 Completion = iota + TextDavinci002 + TextCurie001 + TextBabbage001 + TextAda001 + TextDavinci001 + DavinciInstructBeta + Davinci + CurieInstructBeta + Curie + Ada + Babbage +) + +var completionToString = map[Completion]string{ + TextDavinci003: "text-davinci-003", + TextDavinci002: "text-davinci-002", + TextCurie001: "text-curie-001", + TextBabbage001: "text-babbage-001", + TextAda001: "text-ada-001", + TextDavinci001: "text-davinci-001", + DavinciInstructBeta: "davinci-instruct-beta", + Davinci: "davinci", + CurieInstructBeta: "curie-instruct-beta", + Curie: "curie", + Ada: "ada", + Babbage: "babbage", +} + +// Codex Defines the models provided by OpenAI. +// These models are designed for code-specific tasks, and use +// a different tokenizer which optimizes for whitespace. +const ( + CodexCodeDavinci002 = "code-davinci-002" + CodexCodeCushman001 = "code-cushman-001" + CodexCodeDavinci001 = "code-davinci-001" +) diff --git a/models/embeddings.go b/models/embeddings.go new file mode 100644 index 0000000..2ed0a27 --- /dev/null +++ b/models/embeddings.go @@ -0,0 +1,89 @@ +package models + +// Embedding enumerates the models which can be used +// to generate Embedding vectors. +type Embedding int + +// String implements the fmt.Stringer interface. +func (e Embedding) String() string { + return enumToString[e] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (e Embedding) MarshalText() ([]byte, error) { + return []byte(e.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to Unknown. 
+func (e *Embedding) UnmarshalText(b []byte) error { + if val, ok := stringToEnum[(string(b))]; ok { + *e = val + return nil + } + + *e = Unknown + + return nil +} + +const ( + Unknown Embedding = iota + AdaSimilarity + BabbageSimilarity + CurieSimilarity + DavinciSimilarity + AdaSearchDocument + AdaSearchQuery + BabbageSearchDocument + BabbageSearchQuery + CurieSearchDocument + CurieSearchQuery + DavinciSearchDocument + DavinciSearchQuery + AdaCodeSearchCode + AdaCodeSearchText + BabbageCodeSearchCode + BabbageCodeSearchText + AdaEmbeddingV2 +) + +var enumToString = map[Embedding]string{ + AdaSimilarity: "text-similarity-ada-001", + BabbageSimilarity: "text-similarity-babbage-001", + CurieSimilarity: "text-similarity-curie-001", + DavinciSimilarity: "text-similarity-davinci-001", + AdaSearchDocument: "text-search-ada-doc-001", + AdaSearchQuery: "text-search-ada-query-001", + BabbageSearchDocument: "text-search-babbage-doc-001", + BabbageSearchQuery: "text-search-babbage-query-001", + CurieSearchDocument: "text-search-curie-doc-001", + CurieSearchQuery: "text-search-curie-query-001", + DavinciSearchDocument: "text-search-davinci-doc-001", + DavinciSearchQuery: "text-search-davinci-query-001", + AdaCodeSearchCode: "code-search-ada-code-001", + AdaCodeSearchText: "code-search-ada-text-001", + BabbageCodeSearchCode: "code-search-babbage-code-001", + BabbageCodeSearchText: "code-search-babbage-text-001", + AdaEmbeddingV2: "text-embedding-ada-002", +} + +var stringToEnum = map[string]Embedding{ + "text-similarity-ada-001": AdaSimilarity, + "text-similarity-babbage-001": BabbageSimilarity, + "text-similarity-curie-001": CurieSimilarity, + "text-similarity-davinci-001": DavinciSimilarity, + "text-search-ada-doc-001": AdaSearchDocument, + "text-search-ada-query-001": AdaSearchQuery, + "text-search-babbage-doc-001": BabbageSearchDocument, + "text-search-babbage-query-001": BabbageSearchQuery, + "text-search-curie-doc-001": CurieSearchDocument, + 
"text-search-curie-query-001": CurieSearchQuery, + "text-search-davinci-doc-001": DavinciSearchDocument, + "text-search-davinci-query-001": DavinciSearchQuery, + "code-search-ada-code-001": AdaCodeSearchCode, + "code-search-ada-text-001": AdaCodeSearchText, + "code-search-babbage-code-001": BabbageCodeSearchCode, + "code-search-babbage-text-001": BabbageCodeSearchText, + "text-embedding-ada-002": AdaEmbeddingV2, +} diff --git a/models/enum.go b/models/enum.go new file mode 100644 index 0000000..2640e7f --- /dev/null +++ b/models/enum.go @@ -0,0 +1 @@ +package models From 3c8ebfd187b077ea581d1e5be05604c61df02f27 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sat, 14 Jan 2023 12:17:26 -0800 Subject: [PATCH 05/30] good progress --- answers.go | 55 ------------ api.go | 207 ++++++++++++++++++++++++++++++++++--------- completion.go | 41 ++++----- edits.go | 58 ++++++------ embeddings.go | 23 ++--- engines.go | 48 ++++++---- error.go | 30 +++++-- files.go | 118 +++++++----------------- image.go | 32 ++++--- models/embeddings.go | 17 ++++ 10 files changed, 343 insertions(+), 286 deletions(-) delete mode 100644 answers.go diff --git a/answers.go b/answers.go deleted file mode 100644 index bd2b463..0000000 --- a/answers.go +++ /dev/null @@ -1,55 +0,0 @@ -package openai - -import ( - "bytes" - "context" - "encoding/json" - "net/http" -) - -type AnswerRequest struct { - Documents []string `json:"documents,omitempty"` - File string `json:"file,omitempty"` - Question string `json:"question"` - SearchModel string `json:"search_model,omitempty"` - Model string `json:"model"` - ExamplesContext string `json:"examples_context"` - Examples [][]string `json:"examples"` - MaxTokens int `json:"max_tokens,omitempty"` - Stop []string `json:"stop,omitempty"` - Temperature *float64 `json:"temperature,omitempty"` -} - -type AnswerResponse struct { - Answers []string `json:"answers"` - Completion string `json:"completion"` - Model string `json:"model"` - Object string `json:"object"` - 
SearchModel string `json:"search_model"` - SelectedDocuments []struct { - Document int `json:"document"` - Text string `json:"text"` - } `json:"selected_documents"` -} - -// Answers ... -func (c *Client) Answers(ctx context.Context, ar AnswerRequest) (*AnswerResponse, error) { - var b, err = json.Marshal(ar) - if err != nil { - return nil, err - } - - var req *http.Request - req, err = http.NewRequest("POST", c.fullURL("/answers"), bytes.NewBuffer(b)) - if err != nil { - return nil, err - } - req = req.WithContext(ctx) - - var resp *AnswerResponse - if err = c.sendRequest(req, resp); err != nil { - return nil, err - } - - return resp, err -} diff --git a/api.go b/api.go index c54b3b7..7ebd172 100644 --- a/api.go +++ b/api.go @@ -1,85 +1,206 @@ package openai import ( + "bytes" + "context" "encoding/json" "fmt" + "io" + "mime/multipart" "net/http" + "net/url" + "os" + "path" + "strings" ) -const apiURLv1 = "https://api.openai.com/v1" +const ( + scheme = "https" + host = "api.openai.com" + basePath = "vi" +) -func newTransport() *http.Client { - return &http.Client{} +func reqURL(route string) string { + var u = &url.URL{ + Scheme: scheme, + Host: host, + Path: path.Join(basePath, route), + } + return u.String() } // Client is OpenAI GPT-3 API client. type Client struct { - BaseURL string - HTTPClient *http.Client - authToken string - idOrg string + token string + orgID *string } // NewClient creates new OpenAI API client. -func NewClient(authToken string) *Client { +func NewClient(token string) *Client { return &Client{ - BaseURL: apiURLv1, - HTTPClient: newTransport(), - authToken: authToken, - idOrg: "", + token: token, } } -// NewOrgClient creates new OpenAI API client for specified Organization ID. -func NewOrgClient(authToken, org string) *Client { +// NewClientWithOrg creates new OpenAI API client for specified Organization ID. 
+func NewClientWithOrg(token, org string) *Client { return &Client{ - BaseURL: apiURLv1, - HTTPClient: newTransport(), - authToken: authToken, - idOrg: org, + token: token, + orgID: &org, } } -func (c *Client) sendRequest(req *http.Request, v interface{}) error { - req.Header.Set("Accept", "application/json; charset=utf-8") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.authToken)) +func (c *Client) post(ctx context.Context, path string, payload any) ([]byte, error) { + var b, err = json.Marshal(payload) + if err != nil { + return nil, err + } + + var req *http.Request + req, err = http.NewRequestWithContext(ctx, "POST", reqURL(path), bytes.NewBuffer(b)) + if err != nil { + return nil, err + } - // Check whether Content-Type is already set, Upload Files API requires - // Content-Type == multipart/form-data - contentType := req.Header.Get("Content-Type") - if contentType == "" { + switch payload.(type) { + case FileRequest: + req.Header.Set("Content-Type", "") // TODO + default: req.Header.Set("Content-Type", "application/json; charset=utf-8") } - if len(c.idOrg) > 0 { - req.Header.Set("OpenAI-Organization", c.idOrg) + if c.orgID != nil { + req.Header.Set("OpenAI-Organization", *c.orgID) } - res, err := c.HTTPClient.Do(req) + var resp *http.Response + resp, err = http.DefaultClient.Do(req) if err != nil { - return err + return nil, err } + defer resp.Body.Close() - defer res.Body.Close() + if err = interpretResponse(resp); err != nil { + return nil, err + } - if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusBadRequest { - var errRes ErrorResponse - err = json.NewDecoder(res.Body).Decode(&errRes) - if err != nil || errRes.Error == nil { - return fmt.Errorf("error, status code: %d", res.StatusCode) - } - return fmt.Errorf("error, status code: %d, message: %s", res.StatusCode, errRes.Error.Message) + return io.ReadAll(resp.Body) +} + +// TODO: improve this. 
+func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) { + var b bytes.Buffer + w := multipart.NewWriter(&b) + + var pw, err = w.CreateFormField("purpose") + if err != nil { + return nil, err } - if v != nil { - if err = json.NewDecoder(res.Body).Decode(&v); err != nil { - return err + _, err = io.Copy(pw, strings.NewReader(fr.Purpose)) + if err != nil { + return nil, err + } + + var fw io.Writer + fw, err = w.CreateFormFile("file", fr.FileName) + if err != nil { + return nil, err + } + + var fileData io.ReadCloser + if isURL(fr.FilePath) { + var remoteFile *http.Response + remoteFile, err = http.Get(fr.FilePath) + if err != nil { + return nil, err + } + + defer remoteFile.Body.Close() + + // Check server response + if remoteFile.StatusCode != http.StatusOK { + return nil, fmt.Errorf("error, status code: %d, message: failed to fetch file", remoteFile.StatusCode) + } + + fileData = remoteFile.Body + } else { + fileData, err = os.Open(fr.FilePath) + if err != nil { + return nil, err } } - return nil + _, err = io.Copy(fw, fileData) + if err != nil { + return nil, err + } + + w.Close() + + var req *http.Request + req, err = http.NewRequestWithContext(ctx, "POST", reqURL(routeFiles), &b) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", w.FormDataContentType()) + + var resp *http.Response + resp, err = http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if err = interpretResponse(resp); err != nil { + return nil, err + } + + return io.ReadAll(resp.Body) } -func (c *Client) fullURL(suffix string) string { - return fmt.Sprintf("%s%s", c.BaseURL, suffix) +func (c *Client) get(ctx context.Context, path string) ([]byte, error) { + var req, err = http.NewRequestWithContext(ctx, "POST", reqURL(path), nil) + if err != nil { + return nil, err + } + + if c.orgID != nil { + req.Header.Set("OpenAI-Organization", *c.orgID) + } + + var resp *http.Response + resp, err = 
http.DefaultClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if err = interpretResponse(resp); err != nil { + return nil, err + } + + return io.ReadAll(resp.Body) +} + +func (c *Client) delete(ctx context.Context, path string) error { + var req, err = http.NewRequestWithContext(ctx, "DELETE", reqURL(path), nil) + if err != nil { + return err + } + + var resp *http.Response + resp, err = http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + return interpretResponse(resp) +} + +// TODO: implement. +func interpretResponse(resp *http.Response) error { + return nil } diff --git a/completion.go b/completion.go index 37c1845..802fb75 100644 --- a/completion.go +++ b/completion.go @@ -1,17 +1,16 @@ package openai import ( - "bytes" "context" "encoding/json" - "net/http" + "github.com/fabiustech/openai/models" ) // CompletionRequest represents a request structure for Completion API. type CompletionRequest struct { // Model specifies the ID of the model to use. // See more here: https://beta.openai.com/docs/models/overview - Model string `json:"model"` + Model models.Completion `json:"model"` // Prompt specifies the prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, // or array of token arrays. Note that <|endoftext|> is the document separator that the model sees during // training, so if a prompt is not specified the model will generate as if from the beginning of a new document. @@ -91,10 +90,10 @@ type CompletionRequest struct { // CompletionChoice represents one of possible completions. type CompletionChoice struct { - Text string `json:"text"` - Index int `json:"index"` - FinishReason string `json:"finish_reason"` - LogProbs LogprobResult `json:"logprobs"` + Text string `json:"text"` + Index int `json:"index"` + FinishReason string `json:"finish_reason"` + LogProbs *LogprobResult `json:"logprobs"` } // LogprobResult represents logprob result of Choice. 
@@ -107,35 +106,29 @@ type LogprobResult struct { // CompletionResponse represents a response structure for completion API. type CompletionResponse struct { - ID string `json:"id"` - Object string `json:"object"` - Created uint64 `json:"created"` - Model string `json:"model"` - Choices []CompletionChoice `json:"choices"` - Usage Usage `json:"usage"` + ID string `json:"id"` + Object string `json:"object"` + Created uint64 `json:"created"` + Model models.Completion `json:"model"` + Choices []*CompletionChoice `json:"choices"` + Usage *Usage `json:"usage"` } +const routeCompletions = "completions" + // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well // as, if requested, the probabilities over each alternative token at each position. // // If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object, // and the server will use the model's parameters to generate the completion. -func (c *Client) CreateCompletion(ctx context.Context, cr CompletionRequest) (*CompletionResponse, error) { - var b, err = json.Marshal(cr) - if err != nil { - return nil, err - } - - urlSuffix := "/completions" - var req *http.Request - req, err = http.NewRequest("POST", c.fullURL(urlSuffix), bytes.NewBuffer(b)) +func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest) (*CompletionResponse, error) { + var b, err = c.post(ctx, routeCompletions, cr) if err != nil { return nil, err } - req = req.WithContext(ctx) var resp *CompletionResponse - if err = c.sendRequest(req, resp); err != nil { + if err = json.Unmarshal(b, resp); err != nil { return nil, err } diff --git a/edits.go b/edits.go index a63875a..17e3c48 100644 --- a/edits.go +++ b/edits.go @@ -1,20 +1,32 @@ package openai import ( - "bytes" "context" "encoding/json" - "net/http" ) // EditsRequest represents a request structure for Edits API. 
type EditsRequest struct { - Model *string `json:"model,omitempty"` - Input string `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - N int `json:"n,omitempty"` - Temperature float32 `json:"temperature,omitempty"` - TopP float32 `json:"top_p,omitempty"` + Model string `json:"model"` + // Input is the input text to use as a starting point for the edit. + // Defaults to "". + Input *string `json:"input,omitempty"` + // Instruction is the instruction that tells the model how to edit the prompt. + Instruction string `json:"instruction,omitempty"` + // N specifies how many edits to generate for the input and instruction. + // Defaults to 1. + N *int `json:"n,omitempty"` + // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative + // applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally recommends altering + // this or top_p but not both. + // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 + // Defaults to 1. + Temperature *float32 `json:"temperature,omitempty"` + // TopP specifies an alternative to sampling with temperature, called nucleus sampling, where the model considers + // the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + // probability mass are considered. OpenAI generally recommends altering this or temperature but not both. + // Defaults to 1. + TopP *float32 `json:"top_p,omitempty"` } // EditsChoice represents one of possible edits. @@ -25,26 +37,22 @@ type EditsChoice struct { // EditsResponse represents a response structure for Edits API. 
type EditsResponse struct { - Object string `json:"object"` - Created uint64 `json:"created"` - Usage Usage `json:"usage"` - Choices []EditsChoice `json:"choices"` + Object string `json:"object"` // "edit" + Created uint64 `json:"created"` + Usage *Usage `json:"usage"` + Choices []*EditsChoice `json:"choices"` } -// Perform an API call to the Edits endpoint. -func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) - if err != nil { - return - } +const editsRoute = "edits" + +// Edits ... +func (c *Client) Edits(ctx context.Context, er *EditsRequest) (*EditsResponse, error) { + var b, err = c.post(ctx, editsRoute, er) - req, err := http.NewRequest("POST", c.fullURL("/edits"), bytes.NewBuffer(reqBytes)) - if err != nil { - return + var resp *EditsResponse + if err = json.Unmarshal(b, resp); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return + return resp, nil } diff --git a/embeddings.go b/embeddings.go index f811337..302b271 100644 --- a/embeddings.go +++ b/embeddings.go @@ -1,11 +1,9 @@ package openai import ( - "bytes" "context" "encoding/json" "github.com/fabiustech/openai/models" - "net/http" ) // Embedding is a special format of data representation that can be easily utilized by machine @@ -44,23 +42,20 @@ type EmbeddingRequest struct { User string `json:"user"` } +const embeddingsRoute = "embeddings" + // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. 
// https://beta.openai.com/docs/api-reference/embeddings/create -func (c *Client) CreateEmbeddings(ctx context.Context, request EmbeddingRequest) (resp EmbeddingResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) +func (c *Client) CreateEmbeddings(ctx context.Context, request *EmbeddingRequest) (*EmbeddingResponse, error) { + var b, err = c.post(ctx, embeddingsRoute, request) if err != nil { - return + return nil, err } - urlSuffix := "/embeddings" - req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) - if err != nil { - return + var resp *EmbeddingResponse + if err = json.Unmarshal(b, resp); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &resp) - - return + return resp, nil } diff --git a/engines.go b/engines.go index 350241b..d1aa5f3 100644 --- a/engines.go +++ b/engines.go @@ -2,8 +2,8 @@ package openai import ( "context" - "fmt" - "net/http" + "encoding/json" + "path" ) // Engine struct represents engine from OpenAPI API. @@ -16,35 +16,45 @@ type Engine struct { // EnginesList is a list of engines. type EnginesList struct { - Engines []Engine `json:"data"` + Engines []*Engine `json:"data"` } +const routeEngines = "engines" + // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. -func (c *Client) ListEngines(ctx context.Context) (engines EnginesList, err error) { - req, err := http.NewRequest("GET", c.fullURL("/engines"), nil) +// +// Deprecated: Please use their replacement, Models, instead. 
+// https://beta.openai.com/docs/api-reference/models +func (c *Client) ListEngines(ctx context.Context) (*EnginesList, error) { + var b, err = c.get(ctx, routeEngines) if err != nil { - return + return nil, err + } + + var el *EnginesList + if err = json.Unmarshal(b, el); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &engines) - return + return el, nil } // GetEngine Retrieves an engine instance, providing basic information about // the engine such as the owner and availability. -func (c *Client) GetEngine( - ctx context.Context, - engineID string, -) (engine Engine, err error) { - urlSuffix := fmt.Sprintf("/engines/%s", engineID) - req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil) +// +// Deprecated: Please use their replacement, Models, instead. +// https://beta.openai.com/docs/api-reference/models +func (c *Client) GetEngine(ctx context.Context, id string) (*Engine, error) { + var b, err = c.get(ctx, path.Join(routeEngines, id)) if err != nil { - return + return nil, err + } + + var e *Engine + if err = json.Unmarshal(b, e); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &engine) - return + return e, nil } diff --git a/error.go b/error.go index 3f65b66..7d34401 100644 --- a/error.go +++ b/error.go @@ -1,10 +1,28 @@ package openai +import ( + "fmt" + "net/http" +) + type ErrorResponse struct { - Error *struct { - Code *int `json:"code,omitempty"` - Message string `json:"message"` - Param *string `json:"param,omitempty"` - Type string `json:"type"` - } `json:"error,omitempty"` + Error *Error `json:"error,omitempty"` +} + +type Error struct { + Code int `json:"code,omitempty"` + Message string `json:"message"` + Param *string `json:"param,omitempty"` + Type string `json:"type"` +} + +func (e *Error) Error() string { + return fmt.Sprintf("Code: %v, Message: %s, Type: %s, Param: %v", e.Code, e.Message, e.Type, e.Param) +} + +func (e *Error) Retryable() bool { + if e.Code 
>= http.StatusInternalServerError { + return true + } + return e.Code == http.StatusTooManyRequests } diff --git a/files.go b/files.go index 0dfb41b..582ae98 100644 --- a/files.go +++ b/files.go @@ -1,17 +1,13 @@ package openai import ( - "bytes" "context" - "fmt" - "io" - "mime/multipart" - "net/http" + "encoding/json" "net/url" - "os" - "strings" + "path" ) +// FileRequest ... type FileRequest struct { FileName string `json:"file"` FilePath string `json:"-"` @@ -52,105 +48,55 @@ func isURL(path string) bool { // CreateFile uploads a jsonl file to GPT3 // FilePath can be either a local file path or a URL. -func (c *Client) CreateFile(ctx context.Context, request FileRequest) (file File, err error) { - var b bytes.Buffer - w := multipart.NewWriter(&b) - - var fw, pw io.Writer - pw, err = w.CreateFormField("purpose") - if err != nil { - return - } - - _, err = io.Copy(pw, strings.NewReader(request.Purpose)) - if err != nil { - return - } - - fw, err = w.CreateFormFile("file", request.FileName) - if err != nil { - return - } - - var fileData io.ReadCloser - if isURL(request.FilePath) { - var remoteFile *http.Response - remoteFile, err = http.Get(request.FilePath) - if err != nil { - return - } - - defer remoteFile.Body.Close() - - // Check server response - if remoteFile.StatusCode != http.StatusOK { - err = fmt.Errorf("error, status code: %d, message: failed to fetch file", remoteFile.StatusCode) - return - } - - fileData = remoteFile.Body - } else { - fileData, err = os.Open(request.FilePath) - if err != nil { - return - } - } - - _, err = io.Copy(fw, fileData) +func (c *Client) CreateFile(ctx context.Context, fr *FileRequest) (*File, error) { + var b, err = c.postFile(ctx, fr) if err != nil { - return + return nil, err } - w.Close() - - req, err := http.NewRequest("POST", c.fullURL("/files"), &b) - if err != nil { - return + var f *File + if err = json.Unmarshal(b, f); err != nil { + return nil, err } - req = req.WithContext(ctx) - req.Header.Set("Content-Type", 
w.FormDataContentType()) - - err = c.sendRequest(req, &file) - - return + return f, nil } -// DeleteFile deletes an existing file. -func (c *Client) DeleteFile(ctx context.Context, fileID string) (err error) { - req, err := http.NewRequest("DELETE", c.fullURL("/files/"+fileID), nil) - if err != nil { - return - } +const routeFiles = "files" - req = req.WithContext(ctx) - err = c.sendRequest(req, nil) - return +// DeleteFile deletes an existing file. +func (c *Client) DeleteFile(ctx context.Context, id string) error { + return c.delete(ctx, path.Join(routeFiles, id)) } // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. -func (c *Client) ListFiles(ctx context.Context) (files FilesList, err error) { - req, err := http.NewRequest("GET", c.fullURL("/files"), nil) +func (c *Client) ListFiles(ctx context.Context) (*FilesList, error) { + var b, err = c.get(ctx, routeFiles) if err != nil { - return + return nil, err + } + + var fl *FilesList + if err = json.Unmarshal(b, fl); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &files) - return + return fl, nil } // GetFile Retrieves a file instance, providing basic information about the file // such as the file name and purpose. 
-func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err error) { - urlSuffix := fmt.Sprintf("/files/%s", fileID) - req, err := http.NewRequest("GET", c.fullURL(urlSuffix), nil) +func (c *Client) GetFile(ctx context.Context, id string) (*File, error) { + var b, err = c.get(ctx, path.Join(routeFiles, id)) if err != nil { - return + return nil, err + } + + var f *File + if err = json.Unmarshal(b, f); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &file) - return + return f, nil } diff --git a/image.go b/image.go index 03292d6..de026b0 100644 --- a/image.go +++ b/image.go @@ -1,13 +1,12 @@ package openai import ( - "bytes" "context" "encoding/json" - "net/http" ) // Image sizes defined by the OpenAI API. +// TODO: make enum. const ( CreateImageSize256x256 = "256x256" CreateImageSize512x512 = "512x512" @@ -34,27 +33,32 @@ type ImageResponse struct { Data []ImageResponseDataInner `json:"data,omitempty"` } -// ImageResponseData represents a response data structure for image API. +// ImageResponseDataInner represents a response data structure for image API. type ImageResponseDataInner struct { URL string `json:"url,omitempty"` B64JSON string `json:"b64_json,omitempty"` } +const ( + routeGenerations = "images/generations" + routeEdits = "images/edits" + routeVariations = "images/variations" +) + // CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. 
-func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) +func (c *Client) CreateImage(ctx context.Context, ir *ImageRequest) (*ImageResponse, error) { + var b, err = c.post(ctx, routeGenerations, ir) if err != nil { - return + return nil, err } - urlSuffix := "/images/generations" - req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes)) - if err != nil { - return + var resp *ImageResponse + if err = json.Unmarshal(b, resp); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return + return resp, nil } + +// TODO: ImageEdit +// TODO: ImageVariation diff --git a/models/embeddings.go b/models/embeddings.go index 2ed0a27..a09baa1 100644 --- a/models/embeddings.go +++ b/models/embeddings.go @@ -29,22 +29,39 @@ func (e *Embedding) UnmarshalText(b []byte) error { const ( Unknown Embedding = iota + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. AdaSimilarity + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. BabbageSimilarity + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. CurieSimilarity + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. DavinciSimilarity + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. AdaSearchDocument + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. AdaSearchQuery + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. BabbageSearchDocument + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. BabbageSearchQuery + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. 
CurieSearchDocument + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. CurieSearchQuery + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. DavinciSearchDocument + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. DavinciSearchQuery + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. AdaCodeSearchCode + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. AdaCodeSearchText + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. BabbageCodeSearchCode + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. BabbageCodeSearchText + AdaEmbeddingV2 ) From a3c825f2e1edc7697ad4d28a1d7b52d4da15ff2e Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sat, 14 Jan 2023 16:24:29 -0800 Subject: [PATCH 06/30] wip --- api.go | 50 +++++++++++++++--------------- api_test.go | 12 ++++---- files.go | 6 ++-- image.go | 78 ++++++++++++++++++++++++++++++++--------------- images/formats.go | 42 +++++++++++++++++++++++++ images/sizes.go | 47 ++++++++++++++++++++++++++++ 6 files changed, 177 insertions(+), 58 deletions(-) create mode 100644 images/formats.go create mode 100644 images/sizes.go diff --git a/api.go b/api.go index 7ebd172..43e8eb1 100644 --- a/api.go +++ b/api.go @@ -90,7 +90,7 @@ func (c *Client) post(ctx context.Context, path string, payload any) ([]byte, er // TODO: improve this. 
func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) { var b bytes.Buffer - w := multipart.NewWriter(&b) + var w = multipart.NewWriter(&b) var pw, err = w.CreateFormField("purpose") if err != nil { @@ -108,31 +108,14 @@ func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) return nil, err } - var fileData io.ReadCloser - if isURL(fr.FilePath) { - var remoteFile *http.Response - remoteFile, err = http.Get(fr.FilePath) - if err != nil { - return nil, err - } - - defer remoteFile.Body.Close() - - // Check server response - if remoteFile.StatusCode != http.StatusOK { - return nil, fmt.Errorf("error, status code: %d, message: failed to fetch file", remoteFile.StatusCode) - } - - fileData = remoteFile.Body - } else { - fileData, err = os.Open(fr.FilePath) - if err != nil { - return nil, err - } + var file io.ReadCloser + file, err = readFile(fr.FilePath) + if err != nil { + return nil, err } + defer file.Close() - _, err = io.Copy(fw, fileData) - if err != nil { + if _, err = io.Copy(fw, file); err != nil { return nil, err } @@ -160,6 +143,25 @@ func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) return io.ReadAll(resp.Body) } +func readFile(path string) (io.ReadCloser, error) { + if !isURL(path) { + return os.Open(path) + } + + var resp, err = http.Get(path) + if err != nil { + return nil, err + } + + // Check server response. 
+ if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return nil, fmt.Errorf("error, status code: %d, message: failed to fetch file", resp.StatusCode) + } + + return resp.Body, nil +} + func (c *Client) get(ctx context.Context, path string) ([]byte, error) { var req, err = http.NewRequestWithContext(ctx, "POST", reqURL(path), nil) if err != nil { diff --git a/api_test.go b/api_test.go index c3e9cf6..1168d6b 100644 --- a/api_test.go +++ b/api_test.go @@ -269,7 +269,7 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq openai.ImageRequest + var imageReq openai.CreateImageRequest if imageReq, err = getImageBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return @@ -311,16 +311,16 @@ func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { } // getImageBody Returns the body of the request to create a image. 
-func getImageBody(r *http.Request) (openai.ImageRequest, error) { - image := openai.ImageRequest{} +func getImageBody(r *http.Request) (openai.CreateImageRequest, error) { + image := openai.CreateImageRequest{} // read the request body reqBody, err := ioutil.ReadAll(r.Body) if err != nil { - return openai.ImageRequest{}, err + return openai.CreateImageRequest{}, err } err = json.Unmarshal(reqBody, &image) if err != nil { - return openai.ImageRequest{}, err + return openai.CreateImageRequest{}, err } return image, nil } @@ -345,7 +345,7 @@ func TestImages(t *testing.T) { ctx := context.Background() client.BaseURL = ts.URL + "/v1" - req := openai.ImageRequest{} + req := openai.CreateImageRequest{} req.Prompt = "Lorem ipsum" _, err = client.CreateImage(ctx, req) if err != nil { diff --git a/files.go b/files.go index 582ae98..1bd91d4 100644 --- a/files.go +++ b/files.go @@ -33,13 +33,11 @@ type FilesList struct { // isUrl is a helper function that determines whether the given FilePath // is a remote URL or a local file path. func isURL(path string) bool { - _, err := url.ParseRequestURI(path) - if err != nil { + if _, err := url.ParseRequestURI(path); err != nil { return false } - u, err := url.Parse(path) - if err != nil || u.Scheme == "" || u.Host == "" { + if u, err := url.Parse(path); err != nil || u.Scheme == "" || u.Host == "" { return false } diff --git a/image.go b/image.go index de026b0..c670c45 100644 --- a/image.go +++ b/image.go @@ -3,34 +3,38 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/images" ) -// Image sizes defined by the OpenAI API. -// TODO: make enum. -const ( - CreateImageSize256x256 = "256x256" - CreateImageSize512x512 = "512x512" - CreateImageSize1024x1024 = "1024x1024" -) +// CreateImageRequest represents the request structure for the image API. 
+type CreateImageRequest struct { + Prompt string `json:"prompt"` + *ImageRequestFields +} -const ( - CreateImageResponseFormatURL = "url" - CreateImageResponseFormatB64JSON = "b64_json" -) +type EditImageRequest struct { + Image string `json:"image"` + Mask *string `json:"mask,omitempty"` + Prompt string `json:"prompt"` + *ImageRequestFields +} + +type VariationImageRequest struct { + Image string `json:"image"` + *ImageRequestFields +} -// ImageRequest represents the request structure for the image API. -type ImageRequest struct { - Prompt string `json:"prompt,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - User string `json:"user,omitempty"` +type ImageRequestFields struct { + N *int `json:"n,omitempty"` + Size *images.Size `json:"size,omitempty"` + ResponseFormat *images.Format `json:"response_format,omitempty"` + User *string `json:"user,omitempty"` } // ImageResponse represents a response structure for image API. type ImageResponse struct { - Created uint64 `json:"created,omitempty"` - Data []ImageResponseDataInner `json:"data,omitempty"` + Created uint64 `json:"created,omitempty"` + Data []*ImageResponseDataInner `json:"data,omitempty"` } // ImageResponseDataInner represents a response data structure for image API. @@ -45,8 +49,8 @@ const ( routeVariations = "images/variations" ) -// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API. -func (c *Client) CreateImage(ctx context.Context, ir *ImageRequest) (*ImageResponse, error) { +// CreateImage ... +func (c *Client) CreateImage(ctx context.Context, ir *CreateImageRequest) (*ImageResponse, error) { var b, err = c.post(ctx, routeGenerations, ir) if err != nil { return nil, err @@ -60,5 +64,31 @@ func (c *Client) CreateImage(ctx context.Context, ir *ImageRequest) (*ImageRespo return resp, nil } -// TODO: ImageEdit -// TODO: ImageVariation +// EditImage ... 
+func (c *Client) EditImage(ctx context.Context, eir *EditImageRequest) (*ImageResponse, error) { + var b, err = c.post(ctx, routeEdits, eir) + if err != nil { + return nil, err + } + + var resp *ImageResponse + if err = json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} + +func (c *Client) ImageVariation(ctx context.Context, vir *VariationImageRequest) (*ImageResponse, error) { + var b, err = c.post(ctx, routeVariations, vir) + if err != nil { + return nil, err + } + + var resp *ImageResponse + if err = json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} diff --git a/images/formats.go b/images/formats.go new file mode 100644 index 0000000..783eba2 --- /dev/null +++ b/images/formats.go @@ -0,0 +1,42 @@ +package images + +type Format int + +// String implements the fmt.Stringer interface. +func (f Format) String() string { + return formatToString[f] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (f Format) MarshalText() ([]byte, error) { + return []byte(f.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to Unknown. +func (f *Format) UnmarshalText(b []byte) error { + if val, ok := stringToFormat[(string(b))]; ok { + *f = val + return nil + } + + *f = FormatUnkown + + return nil +} + +const ( + FormatUnkown Format = iota + FormatURL + FormatB64JSON +) + +var formatToString = map[Format]string{ + FormatURL: "url", + FormatB64JSON: "b64_json", +} + +var stringToFormat = map[string]Format{ + "url": FormatURL, + "b64_json": FormatB64JSON, +} diff --git a/images/sizes.go b/images/sizes.go new file mode 100644 index 0000000..3d5c652 --- /dev/null +++ b/images/sizes.go @@ -0,0 +1,47 @@ +package images + +type Size int + +// Image sizes defined by the OpenAI API. +// TODO: make enum. 
+const ( + SizeInvalid Size = iota + Size256x256 + Size512x512 + Size1024x1024 +) + +// String implements the fmt.Stringer interface. +func (s Size) String() string { + return imageToString[s] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (s Size) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to Unknown. +func (s *Size) UnmarshalText(b []byte) error { + if val, ok := stringToImage[(string(b))]; ok { + *s = val + return nil + } + + *s = SizeInvalid + + return nil +} + +var imageToString = map[Size]string{ + Size256x256: "256x256", + Size512x512: "512x512", + Size1024x1024: "1024x1024", +} + +var stringToImage = map[string]Size{ + "256x256": Size256x256, + "512x512": Size512x512, + "1024x1024": Size1024x1024, +} From 71cb5bb3eed5f2fcf925cc505bea234fbe22ae94 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sat, 14 Jan 2023 16:26:18 -0800 Subject: [PATCH 07/30] . --- image.go | 1 + 1 file changed, 1 insertion(+) diff --git a/image.go b/image.go index c670c45..880df40 100644 --- a/image.go +++ b/image.go @@ -79,6 +79,7 @@ func (c *Client) EditImage(ctx context.Context, eir *EditImageRequest) (*ImageRe return resp, nil } +// ImageVariation ... 
func (c *Client) ImageVariation(ctx context.Context, vir *VariationImageRequest) (*ImageResponse, error) { var b, err = c.post(ctx, routeVariations, vir) if err != nil { From b988364bbebef023ad8fe41d795f380bf78cbdee Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sat, 14 Jan 2023 16:27:37 -0800 Subject: [PATCH 08/30] rename --- api.go => client.go | 2 +- api_test.go => client_test.go | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename api.go => client.go (99%) rename api_test.go => client_test.go (100%) diff --git a/api.go b/client.go similarity index 99% rename from api.go rename to client.go index 43e8eb1..1c02705 100644 --- a/api.go +++ b/client.go @@ -17,7 +17,7 @@ import ( const ( scheme = "https" host = "api.openai.com" - basePath = "vi" + basePath = "v1" ) func reqURL(route string) string { diff --git a/api_test.go b/client_test.go similarity index 100% rename from api_test.go rename to client_test.go From 4c72350817cfc8833e1db290ae813026ad3d8e5f Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sun, 15 Jan 2023 12:17:26 -0800 Subject: [PATCH 09/30] lots more stuff --- completion.go => completions.go | 5 +- edits.go | 5 +- embeddings.go | 2 - engines.go | 7 +- files.go | 9 +- image.go => images.go | 13 +-- images/formats.go | 20 ++-- images/sizes.go | 11 ++- models/completions.go | 158 +++++++++++++++++++++++++++----- models/edits.go | 21 +++++ models/embeddings.go | 16 +++- models/enum.go | 1 - moderation.go | 23 ++--- routes/routes.go | 39 ++++++++ 14 files changed, 255 insertions(+), 75 deletions(-) rename completion.go => completions.go (98%) rename image.go => images.go (88%) create mode 100644 models/edits.go delete mode 100644 models/enum.go create mode 100644 routes/routes.go diff --git a/completion.go b/completions.go similarity index 98% rename from completion.go rename to completions.go index 802fb75..7837c1d 100644 --- a/completion.go +++ b/completions.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "github.com/fabiustech/openai/models" + 
"github.com/fabiustech/openai/routes" ) // CompletionRequest represents a request structure for Completion API. @@ -114,15 +115,13 @@ type CompletionResponse struct { Usage *Usage `json:"usage"` } -const routeCompletions = "completions" - // CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well // as, if requested, the probabilities over each alternative token at each position. // // If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object, // and the server will use the model's parameters to generate the completion. func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest) (*CompletionResponse, error) { - var b, err = c.post(ctx, routeCompletions, cr) + var b, err = c.post(ctx, routes.Completions, cr) if err != nil { return nil, err } diff --git a/edits.go b/edits.go index 17e3c48..1797cd7 100644 --- a/edits.go +++ b/edits.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/routes" ) // EditsRequest represents a request structure for Edits API. @@ -43,11 +44,9 @@ type EditsResponse struct { Choices []*EditsChoice `json:"choices"` } -const editsRoute = "edits" - // Edits ... func (c *Client) Edits(ctx context.Context, er *EditsRequest) (*EditsResponse, error) { - var b, err = c.post(ctx, editsRoute, er) + var b, err = c.post(ctx, routes.Edits, er) var resp *EditsResponse if err = json.Unmarshal(b, resp); err != nil { diff --git a/embeddings.go b/embeddings.go index 302b271..402c409 100644 --- a/embeddings.go +++ b/embeddings.go @@ -42,8 +42,6 @@ type EmbeddingRequest struct { User string `json:"user"` } -const embeddingsRoute = "embeddings" - // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. 
// https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request *EmbeddingRequest) (*EmbeddingResponse, error) { diff --git a/engines.go b/engines.go index d1aa5f3..09ba942 100644 --- a/engines.go +++ b/engines.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/routes" "path" ) @@ -19,15 +20,13 @@ type EnginesList struct { Engines []*Engine `json:"data"` } -const routeEngines = "engines" - // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. // // Deprecated: Please use their replacement, Models, instead. // https://beta.openai.com/docs/api-reference/models func (c *Client) ListEngines(ctx context.Context) (*EnginesList, error) { - var b, err = c.get(ctx, routeEngines) + var b, err = c.get(ctx, routes.Engines) if err != nil { return nil, err } @@ -46,7 +45,7 @@ func (c *Client) ListEngines(ctx context.Context) (*EnginesList, error) { // Deprecated: Please use their replacement, Models, instead. // https://beta.openai.com/docs/api-reference/models func (c *Client) GetEngine(ctx context.Context, id string) (*Engine, error) { - var b, err = c.get(ctx, path.Join(routeEngines, id)) + var b, err = c.get(ctx, path.Join(routes.Engines, id)) if err != nil { return nil, err } diff --git a/files.go b/files.go index 1bd91d4..04c2f46 100644 --- a/files.go +++ b/files.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/routes" "net/url" "path" ) @@ -60,17 +61,15 @@ func (c *Client) CreateFile(ctx context.Context, fr *FileRequest) (*File, error) return f, nil } -const routeFiles = "files" - // DeleteFile deletes an existing file. 
func (c *Client) DeleteFile(ctx context.Context, id string) error { - return c.delete(ctx, path.Join(routeFiles, id)) + return c.delete(ctx, path.Join(routes.Files, id)) } // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. func (c *Client) ListFiles(ctx context.Context) (*FilesList, error) { - var b, err = c.get(ctx, routeFiles) + var b, err = c.get(ctx, routes.Files) if err != nil { return nil, err } @@ -86,7 +85,7 @@ func (c *Client) ListFiles(ctx context.Context) (*FilesList, error) { // GetFile Retrieves a file instance, providing basic information about the file // such as the file name and purpose. func (c *Client) GetFile(ctx context.Context, id string) (*File, error) { - var b, err = c.get(ctx, path.Join(routeFiles, id)) + var b, err = c.get(ctx, path.Join(routes.Files, id)) if err != nil { return nil, err } diff --git a/image.go b/images.go similarity index 88% rename from image.go rename to images.go index 880df40..b05fd1e 100644 --- a/image.go +++ b/images.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "github.com/fabiustech/openai/images" + "github.com/fabiustech/openai/routes" ) // CreateImageRequest represents the request structure for the image API. @@ -43,15 +44,9 @@ type ImageResponseDataInner struct { B64JSON string `json:"b64_json,omitempty"` } -const ( - routeGenerations = "images/generations" - routeEdits = "images/edits" - routeVariations = "images/variations" -) - // CreateImage ... func (c *Client) CreateImage(ctx context.Context, ir *CreateImageRequest) (*ImageResponse, error) { - var b, err = c.post(ctx, routeGenerations, ir) + var b, err = c.post(ctx, routes.ImageGenerations, ir) if err != nil { return nil, err } @@ -66,7 +61,7 @@ func (c *Client) CreateImage(ctx context.Context, ir *CreateImageRequest) (*Imag // EditImage ... 
func (c *Client) EditImage(ctx context.Context, eir *EditImageRequest) (*ImageResponse, error) { - var b, err = c.post(ctx, routeEdits, eir) + var b, err = c.post(ctx, routes.ImageEdits, eir) if err != nil { return nil, err } @@ -81,7 +76,7 @@ func (c *Client) EditImage(ctx context.Context, eir *EditImageRequest) (*ImageRe // ImageVariation ... func (c *Client) ImageVariation(ctx context.Context, vir *VariationImageRequest) (*ImageResponse, error) { - var b, err = c.post(ctx, routeVariations, vir) + var b, err = c.post(ctx, routes.ImageVariations, vir) if err != nil { return nil, err } diff --git a/images/formats.go b/images/formats.go index 783eba2..f758cfb 100644 --- a/images/formats.go +++ b/images/formats.go @@ -1,7 +1,19 @@ package images +// Format represents the enum values for the formats in which +// generated images are returned. type Format int +const ( + // FormatInvalid represents and invalid Format option. + FormatInvalid Format = iota + // FormatURL specifies that the API will return a url to the generated image. + // URLs will expire after an hour. + FormatURL + // FormatB64JSON specifies that the API will return the image as Base64 data. + FormatB64JSON +) + // String implements the fmt.Stringer interface. func (f Format) String() string { return formatToString[f] @@ -20,17 +32,11 @@ func (f *Format) UnmarshalText(b []byte) error { return nil } - *f = FormatUnkown + *f = FormatInvalid return nil } -const ( - FormatUnkown Format = iota - FormatURL - FormatB64JSON -) - var formatToString = map[Format]string{ FormatURL: "url", FormatB64JSON: "b64_json", diff --git a/images/sizes.go b/images/sizes.go index 3d5c652..0e43ae7 100644 --- a/images/sizes.go +++ b/images/sizes.go @@ -1,13 +1,20 @@ package images +// Size represents the enum values for the image sizes that +// you can generate. Smaller sizes are faster to generate. type Size int -// Image sizes defined by the OpenAI API. -// TODO: make enum. 
const ( + // SizeInvalid represents and invalid Size option. SizeInvalid Size = iota + // Size256x256 specifies that the API will return an image that is + // 256x256 pixels. Size256x256 + // Size512x512 specifies that the API will return an image that is + // 512x512 pixels. Size512x512 + // Size1024x1024 specifies that the API will return an image that is + // 1024x1024 pixels. Size1024x1024 ) diff --git a/models/completions.go b/models/completions.go index a5d569f..7267ed9 100644 --- a/models/completions.go +++ b/models/completions.go @@ -1,42 +1,154 @@ package models +// Completion represents all models available for use with the Completions endpoint. type Completion int const ( - TextDavinci003 Completion = iota + // UnknownCompletion represents and invalid Completion model. + UnknownCompletion Completion = iota + // TextDavinci003 is the most capable GPT-3 model. Can do any task the other models can do, + // often with higher quality, longer output and better instruction-following. Also supports + // inserting completions within text. + // + // Supports up to 4,000 tokens. Training data up to Jun 2021. + TextDavinci003 + // TextDavinci002 is an older version of the most capable GPT-3 model. Can do any task the + // other models can do, often with higher quality, longer output and better + // instruction-following. Also supports inserting completions within text. + // + // Supports up to 4,000 tokens. + // + // Deprecated: Use TextDavinci003 instead. TextDavinci002 + // TextCurie001 is very capable, but faster and lower cost than Davinci. + // + // Supports up to 2,048 tokens. Training data up to Oct 2019. TextCurie001 + // TextBabbage001 is capable of straightforward tasks, very fast, and lower cost. + // + // Supports up to 2,048 tokens. Training data up to Oct 2019. TextBabbage001 + // TextAda001 is capable of very simple tasks, usually the fastest model in the + // GPT-3 series, and lowest cost. + // + // Supports up to 2,048 tokens. 
Training data up to Oct 2019. TextAda001 + // TextDavinci001 ... (?). TextDavinci001 + + // DavinciInstructBeta is the most capable model in the InstructGPT series. + // It is much better at following user intentions than GPT-3 while also being + // more truthful and less toxic. InstructGPT is better than GPT-3 at following + // English instructions. DavinciInstructBeta - Davinci + // CurieInstructBeta is very capable, but faster and lower cost than Davinci. + // It is much better at following user intentions than GPT-3 while also being + // more truthful and less toxic. InstructGPT is better than GPT-3 at following + // English instructions. CurieInstructBeta + + // Davinci most capable of the older versions of the GPT-3 models + // and is intended to be used with the fine-tuning endpoints. + Davinci + // Curie is very capable, but faster and lower cost than Davinci. It is + // an older version of the GPT-3 models and is intended to be used with + // the fine-tuning endpoints. Curie - Ada + // Babbage is capable of straightforward tasks, very fast, and lower cost. + // It is an older version of the GPT-3 models and is intended to be used + // with the fine-tuning endpoints. Babbage + // Ada is capable of very simple tasks, usually the fastest model in the + // GPT-3 series, and lowest cost. It is an older version of the GPT-3 + // models and is intended to be used with the fine-tuning endpoints. + Ada + // CodeDavinci002 is the most capable Codex model. Particularly good at + // translating natural language to code. In addition to completing code, + // also supports inserting completions within code. + // + // Supports up to 8,000 tokens. Training data up to Jun 2021. + CodeDavinci002 + // CodeCushman001 is almost as capable as Davinci Codex, but slightly faster. + // This speed advantage may make it preferable for real-time applications. + // + // Supports up to 2,048 tokens. 
+ CodeCushman001 + // CodeDavinci001 is and older version of the most capable Codex model. + // Particularly good at translating natural language to code. In addition + // to completing code, also supports inserting completions within code. + // + // Deprecated: Use CodeDavinci002 instead. + CodeDavinci001 + + // TextDavinciInsert002 was a beta model released for insertion. + // + // Deprecated: Insertion should be done via the text models. + TextDavinciInsert002 + // TextDavinciInsert001 was a beta model released for insertion. + // + // Deprecated: Insertion should be done via the text models. + TextDavinciInsert001 ) +// String implements the fmt.Stringer interface. +func (c Completion) String() string { + return completionToString[c] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (c Completion) MarshalText() ([]byte, error) { + return []byte(c.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to Unknown. 
+func (c *Completion) UnmarshalText(b []byte) error { + if val, ok := stringToCompletion[(string(b))]; ok { + *c = val + return nil + } + + *c = UnknownCompletion + + return nil +} + var completionToString = map[Completion]string{ - TextDavinci003: "text-davinci-003", - TextDavinci002: "text-davinci-002", - TextCurie001: "text-curie-001", - TextBabbage001: "text-babbage-001", - TextAda001: "text-ada-001", - TextDavinci001: "text-davinci-001", - DavinciInstructBeta: "davinci-instruct-beta", - Davinci: "davinci", - CurieInstructBeta: "curie-instruct-beta", - Curie: "curie", - Ada: "ada", - Babbage: "babbage", + TextDavinci003: "text-davinci-003", + TextDavinci002: "text-davinci-002", + TextCurie001: "text-curie-001", + TextBabbage001: "text-babbage-001", + TextAda001: "text-ada-001", + TextDavinci001: "text-davinci-001", + DavinciInstructBeta: "davinci-instruct-beta", + CurieInstructBeta: "curie-instruct-beta", + Davinci: "davinci", + Curie: "curie", + Ada: "ada", + Babbage: "babbage", + CodeDavinci002: "code-davinci-002", + CodeCushman001: "code-cushman-001", + CodeDavinci001: "code-davinci-001", + TextDavinciInsert002: "text-davinci-insert-002", + TextDavinciInsert001: "text-davinci-insert-001", } -// Codex Defines the models provided by OpenAI. -// These models are designed for code-specific tasks, and use -// a different tokenizer which optimizes for whitespace. 
-const ( - CodexCodeDavinci002 = "code-davinci-002" - CodexCodeCushman001 = "code-cushman-001" - CodexCodeDavinci001 = "code-davinci-001" -) +var stringToCompletion = map[string]Completion{ + "text-davinci-003": TextDavinci003, + "text-davinci-002": TextDavinci002, + "text-curie-001": TextCurie001, + "text-babbage-001": TextBabbage001, + "text-ada-001": TextAda001, + "text-davinci-001": TextDavinci001, + "davinci-instruct-beta": DavinciInstructBeta, + "curie-instruct-beta": CurieInstructBeta, + "davinci": Davinci, + "curie": Curie, + "ada": Ada, + "babbage": Babbage, + "code-davinci-002": CodeDavinci002, + "code-cushman-001": CodeCushman001, + "code-davinci-001": CodeDavinci001, + "text-davinci-insert-002": TextDavinciInsert002, + "text-davinci-insert-001": TextDavinciInsert001, +} diff --git a/models/edits.go b/models/edits.go new file mode 100644 index 0000000..c63ef55 --- /dev/null +++ b/models/edits.go @@ -0,0 +1,21 @@ +package models + +// Edit represents all models available for use with the Edits endpoint. +type Edit int + +const ( + TextDavinciEdit001 Edit = iota + CodeDavinciEdit001 +) + +var editToString = map[Edit]string{ + // TextDavinciEdit001 can be used to edit text, rather than just completing it. + TextDavinciEdit001: "text-davinci-edit-001", + // CodeDavinciEdit001 can be used to edit code, rather than just completing it. + CodeDavinciEdit001: "code-davinci-edit-001", +} + +var stringToEdit = map[string]Edit{ + "text-davinci-edit-001": TextDavinciEdit001, + "code-davinci-edit-001": CodeDavinciEdit001, +} diff --git a/models/embeddings.go b/models/embeddings.go index a09baa1..ab323cc 100644 --- a/models/embeddings.go +++ b/models/embeddings.go @@ -28,7 +28,21 @@ func (e *Embedding) UnmarshalText(b []byte) error { } const ( + // Unknown represents an invalid Embedding model. Unknown Embedding = iota + + // AdaEmbeddingV2 is the second-generation embedding model. OpenAI recommends using + // text-embedding-ada-002 for nearly all use cases. 
It’s better, cheaper, and simpler to use. + // + // Supports up to 8191. Knowledge cutoff Sep 2021. + AdaEmbeddingV2 + + // The below models are first-generation models (those ending in -001) use the GPT-3 + // tokenizer and have a max input of 2046 tokens. First-generation embeddings are generated + // by five different model families tuned for three different tasks: text search, text similarity + // and code search. The search models come in pairs: one for short queries and one for long documents. + // Each family includes up to four models on a spectrum of quality and speed. + // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. AdaSimilarity // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. @@ -61,8 +75,6 @@ const ( BabbageCodeSearchCode // Deprecated: OpenAI recommends using text-embedding-ada-002 for nearly all use cases. BabbageCodeSearchText - - AdaEmbeddingV2 ) var enumToString = map[Embedding]string{ diff --git a/models/enum.go b/models/enum.go deleted file mode 100644 index 2640e7f..0000000 --- a/models/enum.go +++ /dev/null @@ -1 +0,0 @@ -package models diff --git a/moderation.go b/moderation.go index 54f4ff3..f4ef252 100644 --- a/moderation.go +++ b/moderation.go @@ -1,10 +1,9 @@ package openai import ( - "bytes" "context" "encoding/json" - "net/http" + "github.com/fabiustech/openai/routes" ) // ModerationRequest represents a request structure for moderation API. @@ -49,21 +48,17 @@ type ModerationResponse struct { Results []Result `json:"results"` } -// Moderations — perform a moderation api call over a string. -// Input can be an array or slice but a string will reduce the complexity. -func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (response ModerationResponse, err error) { - var reqBytes []byte - reqBytes, err = json.Marshal(request) +// Moderations ... 
+func (c *Client) Moderations(ctx context.Context, mr *ModerationRequest) (*ModerationResponse, error) { + var b, err = c.post(ctx, routes.Moderations, mr) if err != nil { - return + return nil, err } - req, err := http.NewRequest("POST", c.fullURL("/moderations"), bytes.NewBuffer(reqBytes)) - if err != nil { - return + var resp *ModerationResponse + if err = json.Unmarshal(b, resp); err != nil { + return nil, err } - req = req.WithContext(ctx) - err = c.sendRequest(req, &response) - return + return resp, nil } diff --git a/routes/routes.go b/routes/routes.go new file mode 100644 index 0000000..5b06c5d --- /dev/null +++ b/routes/routes.go @@ -0,0 +1,39 @@ +package routes + +const ( + // Completions is the route for the completions endpoint. + // https://beta.openai.com/docs/api-reference/completions + Completions = "completions" + // Edits is the route for the edits endpoint. + // https://beta.openai.com/docs/api-reference/edits + Edits = "edits" + // Embeddings is the route for the embeddings endpoint. + // https://beta.openai.com/docs/api-reference/embeddings + Embeddings = "embeddings" + + // Engines is the route for the engines endpoint. + // https://beta.openai.com/docs/api-reference/engines + // Deprecated: Use Models instead. + Engines = "engines" + + // Files is the route for the files endpoint. + // https://beta.openai.com/docs/api-reference/files + Files = "files" + + // + imagesBase = "images/" + + // ImageGenerations is the route for the create images endpoint. + // https://beta.openai.com/docs/api-reference/images/create + ImageGenerations = imagesBase + "generations" + // ImageEdits is the route for the create image edits endpoint. + // https://beta.openai.com/docs/api-reference/images/create-edit + ImageEdits = imagesBase + "edits" + // ImageVariations is the route for the create image variations endpoint. 
+ // https://beta.openai.com/docs/api-reference/images/create-variation + ImageVariations = imagesBase + "variations" + + // Moderations is the route for the moderations endpoint. + // https://beta.openai.com/docs/api-reference/moderations + Moderations = "moderations" +) From 1cfc78a7181efebe538f5c9597a21f36b46b2e6d Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sun, 15 Jan 2023 13:22:49 -0800 Subject: [PATCH 10/30] . --- client.go | 35 ++++++++++++++++++++++----------- client_test.go | 2 +- completions.go | 3 ++- edits.go | 6 ++++-- embeddings.go | 13 ++++++++----- files.go | 16 ++++++++------- fine_tunes.go | 3 +++ images.go | 13 +++++++------ models/edits.go | 29 +++++++++++++++++++++++++++- models/embeddings.go | 46 ++++++++++++++++++++++---------------------- moderation.go | 6 +++--- objects/objects.go | 16 +++++++++++++++ 12 files changed, 128 insertions(+), 60 deletions(-) create mode 100644 fine_tunes.go create mode 100644 objects/objects.go diff --git a/client.go b/client.go index 1c02705..ffe6a96 100644 --- a/client.go +++ b/client.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/fabiustech/openai/routes" "io" "mime/multipart" "net/http" @@ -20,15 +21,6 @@ const ( basePath = "v1" ) -func reqURL(route string) string { - var u = &url.URL{ - Scheme: scheme, - Host: host, - Path: path.Join(basePath, route), - } - return u.String() -} - // Client is OpenAI GPT-3 API client. type Client struct { token string @@ -122,7 +114,7 @@ func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) w.Close() var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", reqURL(routeFiles), &b) + req, err = http.NewRequestWithContext(ctx, "POST", reqURL(routes.Files), &b) if err != nil { return nil, err } @@ -202,7 +194,28 @@ func (c *Client) delete(ctx context.Context, path string) error { return interpretResponse(resp) } -// TODO: implement. 
+func reqURL(route string) string { + var u = &url.URL{ + Scheme: scheme, + Host: host, + Path: path.Join(basePath, route), + } + return u.String() +} + func interpretResponse(resp *http.Response) error { + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusBadRequest { + var b, err = io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("error, status code: %d", resp.StatusCode) + } + var er *ErrorResponse + if err = json.Unmarshal(b, er); err != nil || er.Error == nil { + return fmt.Errorf("error, status code: %d, msg: %s", resp.StatusCode, string(b)) + } + + return er.Error + } + return nil } diff --git a/client_test.go b/client_test.go index 1168d6b..d2903ac 100644 --- a/client_test.go +++ b/client_test.go @@ -278,7 +278,7 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { Created: uint64(time.Now().Unix()), } for i := 0; i < imageReq.N; i++ { - imageData := openai.ImageResponseDataInner{} + imageData := openai.ImageData{} switch imageReq.ResponseFormat { case openai.CreateImageResponseFormatURL, "": imageData.URL = "https://example.com/image.png" diff --git a/completions.go b/completions.go index 7837c1d..5cc9e4a 100644 --- a/completions.go +++ b/completions.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" ) @@ -108,7 +109,7 @@ type LogprobResult struct { // CompletionResponse represents a response structure for completion API. 
type CompletionResponse struct { ID string `json:"id"` - Object string `json:"object"` + Object objects.Object `json:"object"` Created uint64 `json:"created"` Model models.Completion `json:"model"` Choices []*CompletionChoice `json:"choices"` diff --git a/edits.go b/edits.go index 1797cd7..64da3e9 100644 --- a/edits.go +++ b/edits.go @@ -3,12 +3,14 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" ) // EditsRequest represents a request structure for Edits API. type EditsRequest struct { - Model string `json:"model"` + Model models.Edit `json:"model"` // Input is the input text to use as a starting point for the edit. // Defaults to "". Input *string `json:"input,omitempty"` @@ -38,7 +40,7 @@ type EditsChoice struct { // EditsResponse represents a response structure for Edits API. type EditsResponse struct { - Object string `json:"object"` // "edit" + Object objects.Object `json:"object"` // "edit" Created uint64 `json:"created"` Usage *Usage `json:"usage"` Choices []*EditsChoice `json:"choices"` diff --git a/embeddings.go b/embeddings.go index 402c409..1673674 100644 --- a/embeddings.go +++ b/embeddings.go @@ -4,6 +4,8 @@ import ( "context" "encoding/json" "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/routes" ) // Embedding is a special format of data representation that can be easily utilized by machine @@ -13,14 +15,15 @@ import ( // between two inputs in the original format. For example, if two texts are similar, // then their vector representations should also be similar. type Embedding struct { - Object string `json:"object"` - Embedding []float64 `json:"embedding"` - Index int `json:"index"` + Object objects.Object `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` } // EmbeddingResponse is the response from a Create embeddings request. 
+// Todo: Wrap type EmbeddingResponse struct { - Object string `json:"object"` + Object objects.Object `json:"object"` Data []Embedding `json:"data"` Model models.Embedding `json:"model"` Usage Usage `json:"usage"` @@ -45,7 +48,7 @@ type EmbeddingRequest struct { // CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. // https://beta.openai.com/docs/api-reference/embeddings/create func (c *Client) CreateEmbeddings(ctx context.Context, request *EmbeddingRequest) (*EmbeddingResponse, error) { - var b, err = c.post(ctx, embeddingsRoute, request) + var b, err = c.post(ctx, routes.Embeddings, request) if err != nil { return nil, err } diff --git a/files.go b/files.go index 04c2f46..80c3977 100644 --- a/files.go +++ b/files.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" "net/url" "path" @@ -17,16 +18,17 @@ type FileRequest struct { // File struct represents an OpenAPI file. type File struct { - Bytes int `json:"bytes"` - CreatedAt int `json:"created_at"` - ID string `json:"id"` - FileName string `json:"filename"` - Object string `json:"object"` - Owner string `json:"owner"` - Purpose string `json:"purpose"` + Bytes int `json:"bytes"` + CreatedAt int `json:"created_at"` + ID string `json:"id"` + FileName string `json:"filename"` + Object objects.Object `json:"object"` + Owner string `json:"owner"` + Purpose string `json:"purpose"` } // FilesList is a list of files that belong to the user or organization. +// TODO: wrap. 
type FilesList struct { Files []File `json:"data"` } diff --git a/fine_tunes.go b/fine_tunes.go new file mode 100644 index 0000000..434a887 --- /dev/null +++ b/fine_tunes.go @@ -0,0 +1,3 @@ +package openai + +// TODO: diff --git a/images.go b/images.go index b05fd1e..6c1f223 100644 --- a/images.go +++ b/images.go @@ -34,14 +34,15 @@ type ImageRequestFields struct { // ImageResponse represents a response structure for image API. type ImageResponse struct { - Created uint64 `json:"created,omitempty"` - Data []*ImageResponseDataInner `json:"data,omitempty"` + Created uint64 `json:"created,omitempty"` + Data []*ImageData `json:"data,omitempty"` } -// ImageResponseDataInner represents a response data structure for image API. -type ImageResponseDataInner struct { - URL string `json:"url,omitempty"` - B64JSON string `json:"b64_json,omitempty"` +// ImageData represents a response data structure for image API. +// Only one field will be non-nil. +type ImageData struct { + URL *string `json:"url,omitempty"` + B64JSON *string `json:"b64_json,omitempty"` } // CreateImage ... diff --git a/models/edits.go b/models/edits.go index c63ef55..d01474f 100644 --- a/models/edits.go +++ b/models/edits.go @@ -4,10 +4,37 @@ package models type Edit int const ( - TextDavinciEdit001 Edit = iota + // UnknownEdit represents an invalid Edit model. + UnknownEdit Edit = iota + // TextDavinciEdit001 ... + TextDavinciEdit001 + // CodeDavinciEdit001 ... CodeDavinciEdit001 ) +// String implements the fmt.Stringer interface. +func (e Edit) String() string { + return editToString[e] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (e Edit) MarshalText() ([]byte, error) { + return []byte(e.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to UnknownEdit. 
+func (e *Edit) UnmarshalText(b []byte) error { + if val, ok := stringToEdit[(string(b))]; ok { + *e = val + return nil + } + + *e = UnknownEdit + + return nil +} + var editToString = map[Edit]string{ // TextDavinciEdit001 can be used to edit text, rather than just completing it. TextDavinciEdit001: "text-davinci-edit-001", diff --git a/models/embeddings.go b/models/embeddings.go index ab323cc..569ddc0 100644 --- a/models/embeddings.go +++ b/models/embeddings.go @@ -4,29 +4,6 @@ package models // to generate Embedding vectors. type Embedding int -// String implements the fmt.Stringer interface. -func (e Embedding) String() string { - return enumToString[e] -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (e Embedding) MarshalText() ([]byte, error) { - return []byte(e.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// On unrecognized value, it sets |e| to Unknown. -func (e *Embedding) UnmarshalText(b []byte) error { - if val, ok := stringToEnum[(string(b))]; ok { - *e = val - return nil - } - - *e = Unknown - - return nil -} - const ( // Unknown represents an invalid Embedding model. Unknown Embedding = iota @@ -77,6 +54,29 @@ const ( BabbageCodeSearchText ) +// String implements the fmt.Stringer interface. +func (e Embedding) String() string { + return enumToString[e] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (e Embedding) MarshalText() ([]byte, error) { + return []byte(e.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to Unknown. 
+func (e *Embedding) UnmarshalText(b []byte) error { + if val, ok := stringToEnum[(string(b))]; ok { + *e = val + return nil + } + + *e = Unknown + + return nil +} + var enumToString = map[Embedding]string{ AdaSimilarity: "text-similarity-ada-001", BabbageSimilarity: "text-similarity-babbage-001", diff --git a/moderation.go b/moderation.go index f4ef252..d5846c0 100644 --- a/moderation.go +++ b/moderation.go @@ -14,9 +14,9 @@ type ModerationRequest struct { // Result represents one of possible moderation results. type Result struct { - Categories ResultCategories `json:"categories"` - CategoryScores ResultCategoryScores `json:"category_scores"` - Flagged bool `json:"flagged"` + Categories *ResultCategories `json:"categories"` + CategoryScores *ResultCategoryScores `json:"category_scores"` + Flagged bool `json:"flagged"` } // ResultCategories represents Categories of Result. diff --git a/objects/objects.go b/objects/objects.go new file mode 100644 index 0000000..5a8bb9a --- /dev/null +++ b/objects/objects.go @@ -0,0 +1,16 @@ +package objects + +type Object int + +const ( + Model = "model" + List = "list" + + TextCompletion = "text_completion" + CodeCompletion = "code_completion" + Edit = "edit" + Embedding = "embedding" + File = "file" + FineTune = "fine-tune" + Engine = "engine" +) From 543f5c2c23b5fca0cdd661c463f282ff38d93b0a Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sun, 15 Jan 2023 13:41:48 -0800 Subject: [PATCH 11/30] more --- client_test.go | 151 +++++++++++++++++++++++---------------------- completions.go | 2 +- edits.go | 2 +- objects/objects.go | 43 ++++++++++--- params/params.go | 6 ++ 5 files changed, 119 insertions(+), 85 deletions(-) create mode 100644 params/params.go diff --git a/client_test.go b/client_test.go index d2903ac..d89a6c1 100644 --- a/client_test.go +++ b/client_test.go @@ -6,7 +6,11 @@ import ( "encoding/json" "fmt" "github.com/fabiustech/openai" - "io/ioutil" + "github.com/fabiustech/openai/images" + 
"github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/params" + "io" "log" "net/http" "net/http/httptest" @@ -52,12 +56,12 @@ func TestAPI(t *testing.T) { } } // else skip - embeddingReq := openai.EmbeddingRequest{ + embeddingReq := &openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: openai.AdaSearchQuery, + Model: models.AdaSearchQuery, } _, err = c.CreateEmbeddings(ctx, embeddingReq) if err != nil { @@ -75,11 +79,11 @@ func TestCompletions(t *testing.T) { client := openai.NewClient(testAPIToken) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" + // client.BaseURL = ts.URL + "/v1" - req := openai.CompletionRequest{ - MaxTokens: 5, - Model: "ada", + req := &openai.CompletionRequest{ + MaxTokens: params.Optional(5), + Model: models.Ada, } req.Prompt = "Lorem ipsum" _, err = client.CreateCompletion(ctx, req) @@ -98,49 +102,50 @@ func TestEdits(t *testing.T) { client := openai.NewClient(testAPIToken) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" + // client.BaseURL = ts.URL + "/v1" // create an edit request - model := "ada" - editReq := openai.EditsRequest{ - Model: &model, + + editReq := &openai.EditsRequest{ + Model: models.TextDavinciEdit001, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + " ex ea commodo consequat. 
Duis aute irure dolor in reprehe", Instruction: "test instruction", - N: 3, + N: params.Optional(3), } response, err := client.Edits(ctx, editReq) if err != nil { t.Fatalf("Edits error: %v", err) } - if len(response.Choices) != editReq.N { + if len(response.Choices) != *editReq.N { t.Fatalf("edits does not properly return the correct number of choices") } } func TestEmbedding(t *testing.T) { - embeddedModels := []openai.EmbeddingModel{ - openai.AdaSimilarity, - openai.BabbageSimilarity, - openai.CurieSimilarity, - openai.DavinciSimilarity, - openai.AdaSearchDocument, - openai.AdaSearchQuery, - openai.BabbageSearchDocument, - openai.BabbageSearchQuery, - openai.CurieSearchDocument, - openai.CurieSearchQuery, - openai.DavinciSearchDocument, - openai.DavinciSearchQuery, - openai.AdaCodeSearchCode, - openai.AdaCodeSearchText, - openai.BabbageCodeSearchCode, - openai.BabbageCodeSearchText, + embeddedModels := []models.Embedding{ + models.AdaSimilarity, + models.BabbageSimilarity, + models.CurieSimilarity, + models.DavinciSimilarity, + models.AdaSearchDocument, + models.AdaSearchQuery, + models.BabbageSearchDocument, + models.BabbageSearchQuery, + models.CurieSearchDocument, + models.CurieSearchQuery, + models.DavinciSearchDocument, + models.DavinciSearchQuery, + models.AdaCodeSearchCode, + models.AdaCodeSearchText, + models.BabbageCodeSearchCode, + models.BabbageCodeSearchText, + models.AdaEmbeddingV2, } for _, model := range embeddedModels { - embeddingReq := openai.EmbeddingRequest{ + embeddingReq := &openai.EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -160,16 +165,16 @@ func TestEmbedding(t *testing.T) { } // getEditBody Returns the body of the request to create an edit. 
-func getEditBody(r *http.Request) (openai.EditsRequest, error) { - edit := openai.EditsRequest{} +func getEditBody(r *http.Request) (*openai.EditsRequest, error) { + edit := &openai.EditsRequest{} // read the request body - reqBody, err := ioutil.ReadAll(r.Body) + reqBody, err := io.ReadAll(r.Body) if err != nil { - return openai.EditsRequest{}, err + return nil, err } err = json.Unmarshal(reqBody, &edit) if err != nil { - return openai.EditsRequest{}, err + return nil, err } return edit, nil } @@ -183,29 +188,29 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq openai.EditsRequest + var editReq *openai.EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := openai.EditsResponse{ - Object: "test-object", + res := &openai.EditsResponse{ + Object: objects.Edit, Created: uint64(time.Now().Unix()), } // edit and calculate token usage editString := "edited by mocked OpenAI server :)" - inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N - completionTokens := int(float32(len(editString))/4) * editReq.N - for i := 0; i < editReq.N; i++ { + inputTokens := numTokens(editReq.Input+editReq.Instruction) * *editReq.N + completionTokens := int(float32(len(editString))/4) * *editReq.N + for i := 0; i < *editReq.N; i++ { // instruction will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, openai.EditsChoice{ + res.Choices = append(res.Choices, &openai.EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = openai.Usage{ + res.Usage = &openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -223,14 +228,14 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, 
"Method not allowed", http.StatusMethodNotAllowed) } - var completionReq openai.CompletionRequest + var completionReq *openai.CompletionRequest if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := openai.CompletionResponse{ + res := &openai.CompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), - Object: "test-object", + Object: objects.TextCompletion, Created: uint64(time.Now().Unix()), // would be nice to validate Model during testing, but // this may not be possible with how much upkeep @@ -238,20 +243,20 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < completionReq.N; i++ { + for i := 0; i < *completionReq.N; i++ { // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", completionReq.MaxTokens) + completionStr := strings.Repeat("a", *completionReq.MaxTokens) if completionReq.Echo { completionStr = completionReq.Prompt + completionStr } - res.Choices = append(res.Choices, openai.CompletionChoice{ + res.Choices = append(res.Choices, &openai.CompletionChoice{ Text: completionStr, Index: i, }) } - inputTokens := numTokens(completionReq.Prompt) * completionReq.N - completionTokens := completionReq.MaxTokens * completionReq.N - res.Usage = openai.Usage{ + inputTokens := numTokens(completionReq.Prompt) * *completionReq.N + completionTokens := *completionReq.MaxTokens * *completionReq.N + res.Usage = &openai.Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -269,22 +274,22 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq openai.CreateImageRequest + var imageReq *openai.CreateImageRequest if imageReq, err = getImageBody(r); err != nil { 
http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := openai.ImageResponse{ + res := &openai.ImageResponse{ Created: uint64(time.Now().Unix()), } - for i := 0; i < imageReq.N; i++ { - imageData := openai.ImageData{} + for i := 0; i < *imageReq.N; i++ { + var imageData = &openai.ImageData{} switch imageReq.ResponseFormat { - case openai.CreateImageResponseFormatURL, "": - imageData.URL = "https://example.com/image.png" - case openai.CreateImageResponseFormatB64JSON: + case params.Optional(images.FormatURL), nil: + imageData.URL = params.Optional("https://example.com/image.png") + case params.Optional(images.FormatB64JSON): // This decodes to "{}" in base64. - imageData.B64JSON = "e30K" + imageData.B64JSON = params.Optional("e30K") default: http.Error(w, "invalid response format", http.StatusBadRequest) return @@ -296,31 +301,31 @@ } // getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (openai.CompletionRequest, error) { - completion := openai.CompletionRequest{} +func getCompletionBody(r *http.Request) (*openai.CompletionRequest, error) { + var completion = &openai.CompletionRequest{} // read the request body - reqBody, err := ioutil.ReadAll(r.Body) + reqBody, err := io.ReadAll(r.Body) if err != nil { - return openai.CompletionRequest{}, err + return nil, err } err = json.Unmarshal(reqBody, &completion) if err != nil { - return openai.CompletionRequest{}, err + return nil, err } return completion, nil } // getImageBody Returns the body of the request to create an image. 
-func getImageBody(r *http.Request) (openai.CreateImageRequest, error) { - image := openai.CreateImageRequest{} +func getImageBody(r *http.Request) (*openai.CreateImageRequest, error) { + var image = &openai.CreateImageRequest{} // read the request body - reqBody, err := ioutil.ReadAll(r.Body) + var reqBody, err = io.ReadAll(r.Body) if err != nil { - return openai.CreateImageRequest{}, err + return nil, err } err = json.Unmarshal(reqBody, &image) if err != nil { - return openai.CreateImageRequest{}, err + return nil, err } return image, nil } @@ -343,9 +348,9 @@ func TestImages(t *testing.T) { client := openai.NewClient(testAPIToken) ctx := context.Background() - client.BaseURL = ts.URL + "/v1" + // client.BaseURL = ts.URL + "/v1" - req := openai.CreateImageRequest{} + req := &openai.CreateImageRequest{} req.Prompt = "Lorem ipsum" _, err = client.CreateImage(ctx, req) if err != nil { diff --git a/completions.go b/completions.go index 5cc9e4a..16fd27a 100644 --- a/completions.go +++ b/completions.go @@ -17,7 +17,7 @@ type CompletionRequest struct { // or array of token arrays. Note that <|endoftext|> is the document separator that the model sees during // training, so if a prompt is not specified the model will generate as if from the beginning of a new document. // Defaults to <|endoftext|>. - Prompt *string `json:"prompt,omitempty"` + Prompt string `json:"prompt,omitempty"` // Suffix specifies the suffix that comes after a completion of inserted text. // Defaults to null. Suffix *string `json:"suffix,omitempty"` diff --git a/edits.go b/edits.go index 64da3e9..54a2d8d 100644 --- a/edits.go +++ b/edits.go @@ -13,7 +13,7 @@ type EditsRequest struct { Model models.Edit `json:"model"` // Input is the input text to use as a starting point for the edit. // Defaults to "". - Input *string `json:"input,omitempty"` + Input string `json:"input,omitempty"` // Instruction is the instruction that tells the model how to edit the prompt. 
Instruction string `json:"instruction,omitempty"` // N specifies how many edits to generate for the input and instruction. diff --git a/objects/objects.go b/objects/objects.go index 5a8bb9a..381ba9a 100644 --- a/objects/objects.go +++ b/objects/objects.go @@ -3,14 +3,37 @@ package objects type Object int const ( - Model = "model" - List = "list" - - TextCompletion = "text_completion" - CodeCompletion = "code_completion" - Edit = "edit" - Embedding = "embedding" - File = "file" - FineTune = "fine-tune" - Engine = "engine" + Model Object = iota + List + TextCompletion + CodeCompletion + Edit + Embedding + File + FineTune + Engine ) + +var objectToString = map[Object]string{ + Model: "model", + List: "list", + TextCompletion: "text_completion", + CodeCompletion: "code_completion", + Edit: "edit", + Embedding: "embedding", + File: "file", + FineTune: "fine-tune", + Engine: "engine", +} + +var stringToModel = map[string]Object{ + "model": Model, + "list": List, + "text_completion": TextCompletion, + "code_completion": CodeCompletion, + "edit": Edit, + "embedding": Embedding, + "file": File, + "fine-tune": FineTune, + "engine": Engine, +} diff --git a/params/params.go b/params/params.go new file mode 100644 index 0000000..8fc1b8b --- /dev/null +++ b/params/params.go @@ -0,0 +1,6 @@ +package params + +// Optional returns a pointer to |v|. 
+func Optional[T any](v T) *T { + return &v +} From b4a9dad83f46f92abc080b8a1c4ba098c3553fc7 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sun, 15 Jan 2023 13:57:33 -0800 Subject: [PATCH 12/30] lots of doc comments --- common.go | 7 +++++-- doc.go | 3 +++ images/formats.go | 2 ++ models/completions.go | 2 ++ objects/objects.go | 42 ++++++++++++++++++++++++++++++++++++++++-- params/params.go | 2 ++ routes/routes.go | 1 + 7 files changed, 55 insertions(+), 4 deletions(-) create mode 100644 doc.go diff --git a/common.go b/common.go index 3121ede..ad84dc6 100644 --- a/common.go +++ b/common.go @@ -2,7 +2,10 @@ package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { - PromptTokens int `json:"prompt_tokens"` + // PromptTokens is the number of tokens in the passed prompt. + PromptTokens int `json:"prompt_tokens"` + // CompletionTokens is the number of tokens in the completion response. CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + // Total tokens is the sum of PromptTokens and CompletionTokens. + TotalTokens int `json:"total_tokens"` } diff --git a/doc.go b/doc.go new file mode 100644 index 0000000..3fa1a00 --- /dev/null +++ b/doc.go @@ -0,0 +1,3 @@ +// Package openai is a client library for interacting with the OpenAI API. +// It supports all non-deprecated endpoints (as well as the Engines endpoint). +package openai diff --git a/images/formats.go b/images/formats.go index f758cfb..e590baa 100644 --- a/images/formats.go +++ b/images/formats.go @@ -1,3 +1,5 @@ +// Package images contains the enum values which represent the various +// image formats and sizes returned by the OpenAI image endpoints. 
package images // Format represents the enum values for the formats in which diff --git a/models/completions.go b/models/completions.go index 7267ed9..ce40670 100644 --- a/models/completions.go +++ b/models/completions.go @@ -1,3 +1,5 @@ +// Package models contains the enum values which represent the various +// models used by all OpenAI endpoints. package models // Completion represents all models available for use with the Completions endpoint. diff --git a/objects/objects.go b/objects/objects.go index 381ba9a..3dc2855 100644 --- a/objects/objects.go +++ b/objects/objects.go @@ -1,19 +1,57 @@ +// Package objects contains the enum values which represent the various +// objects returned by all OpenAI endpoints. package objects +// Object enumerates the various object types returned by OpenAI endpoints. type Object int const ( - Model Object = iota + // Unknown is an invalid object. + Unknown Object = iota + // Model is a model (can be either a base model or fine-tuned). + Model + // List is a list of other objects. List + // TextCompletion is a text completion. TextCompletion + // CodeCompletion is a code completion. CodeCompletion + // Edit is an edit. Edit + // Embedding is an embedding. Embedding + // File is a file. File + // FineTune is a fine-tuned model. FineTune + // Engine represents an engine. + // Deprecated: use Model instead. Engine ) +// String implements the fmt.Stringer interface. +func (o Object) String() string { + return objectToString[o] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (o Object) MarshalText() ([]byte, error) { + return []byte(o.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to Unknown. 
+func (o *Object) UnmarshalText(b []byte) error { + if val, ok := stringToObject[(string(b))]; ok { + *o = val + return nil + } + + *o = Unknown + + return nil +} + var objectToString = map[Object]string{ Model: "model", List: "list", @@ -26,7 +64,7 @@ var objectToString = map[Object]string{ Engine: "engine", } -var stringToModel = map[string]Object{ +var stringToObject = map[string]Object{ "model": Model, "list": List, "text_completion": TextCompletion, diff --git a/params/params.go b/params/params.go index 8fc1b8b..5555311 100644 --- a/params/params.go +++ b/params/params.go @@ -1,3 +1,5 @@ +// Package params provides a helper function to simplify setting optional +// parameters in struct literals. package params // Optional returns a pointer to |v|. diff --git a/routes/routes.go b/routes/routes.go index 5b06c5d..d3945ad 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -1,3 +1,4 @@ +// Package routes contains constants for all OpenAI endpoint routes. package routes const ( From dcf93f6f807fea715e3af7499b13b2deb2e876a4 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sun, 15 Jan 2023 14:10:53 -0800 Subject: [PATCH 13/30] very close --- client.go | 14 ++++++++++++++ common.go | 5 +++-- embeddings.go | 8 +++----- engines.go | 9 ++------- files.go | 25 ++----------------------- list.go | 10 ++++++++++ 6 files changed, 34 insertions(+), 37 deletions(-) create mode 100644 list.go diff --git a/client.go b/client.go index ffe6a96..dcf8c87 100644 --- a/client.go +++ b/client.go @@ -154,6 +154,20 @@ func readFile(path string) (io.ReadCloser, error) { return resp.Body, nil } +// isURL is a helper function that determines whether the given FilePath +// is a remote URL or a local file path. 
+func isURL(path string) bool { + if _, err := url.ParseRequestURI(path); err != nil { + return false + } + + if u, err := url.Parse(path); err != nil || u.Scheme == "" || u.Host == "" { + return false + } + + return true +} + func (c *Client) get(ctx context.Context, path string) ([]byte, error) { var req, err = http.NewRequestWithContext(ctx, "POST", reqURL(path), nil) if err != nil { diff --git a/common.go b/common.go index ad84dc6..b49a141 100644 --- a/common.go +++ b/common.go @@ -2,10 +2,11 @@ package openai // Usage Represents the total token usage per request to OpenAI. type Usage struct { - // PromptTokens is the number of tokens in the passed prompt. + // PromptTokens is the number of tokens in the request's prompt. PromptTokens int `json:"prompt_tokens"` // CompletionTokens is the number of tokens in the completion response. - CompletionTokens int `json:"completion_tokens"` + // Will not be set for requests to the embeddings endpoint. + CompletionTokens int `json:"completion_tokens,omitempty"` // Total tokens is the sum of PromptTokens and CompletionTokens. TotalTokens int `json:"total_tokens"` } diff --git a/embeddings.go b/embeddings.go index 1673674..29fc24d 100644 --- a/embeddings.go +++ b/embeddings.go @@ -21,12 +21,10 @@ type Embedding struct { } // EmbeddingResponse is the response from a Create embeddings request. -// Todo: Wrap type EmbeddingResponse struct { - Object objects.Object `json:"object"` - Data []Embedding `json:"data"` - Model models.Embedding `json:"model"` - Usage Usage `json:"usage"` + *List[*Embedding] + Model models.Embedding + Usage *Usage } // EmbeddingRequest is the input to a Create embeddings request. diff --git a/engines.go b/engines.go index 09ba942..511e5e0 100644 --- a/engines.go +++ b/engines.go @@ -15,23 +15,18 @@ type Engine struct { Ready bool `json:"ready"` } -// EnginesList is a list of engines. 
-type EnginesList struct { - Engines []*Engine `json:"data"` -} - // ListEngines Lists the currently available engines, and provides basic // information about each option such as the owner and availability. // // Deprecated: Please use their replacement, Models, instead. // https://beta.openai.com/docs/api-reference/models -func (c *Client) ListEngines(ctx context.Context) (*EnginesList, error) { +func (c *Client) ListEngines(ctx context.Context) (*List[*Engine], error) { var b, err = c.get(ctx, routes.Engines) if err != nil { return nil, err } - var el *EnginesList + var el *List[*Engine] if err = json.Unmarshal(b, el); err != nil { return nil, err } diff --git a/files.go b/files.go index 80c3977..4d860dc 100644 --- a/files.go +++ b/files.go @@ -5,7 +5,6 @@ import ( "encoding/json" "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" - "net/url" "path" ) @@ -27,26 +26,6 @@ type File struct { Purpose string `json:"purpose"` } -// FilesList is a list of files that belong to the user or organization. -// TODO: wrap. -type FilesList struct { - Files []File `json:"data"` -} - -// isUrl is a helper function that determines whether the given FilePath -// is a remote URL or a local file path. -func isURL(path string) bool { - if _, err := url.ParseRequestURI(path); err != nil { - return false - } - - if u, err := url.Parse(path); err != nil || u.Scheme == "" || u.Host == "" { - return false - } - - return true -} - // CreateFile uploads a jsonl file to GPT3 // FilePath can be either a local file path or a URL. func (c *Client) CreateFile(ctx context.Context, fr *FileRequest) (*File, error) { @@ -70,13 +49,13 @@ func (c *Client) DeleteFile(ctx context.Context, id string) error { // ListFiles Lists the currently available files, // and provides basic information about each file such as the file name and purpose. 
-func (c *Client) ListFiles(ctx context.Context) (*FilesList, error) { +func (c *Client) ListFiles(ctx context.Context) (*List[*File], error) { var b, err = c.get(ctx, routes.Files) if err != nil { return nil, err } - var fl *FilesList + var fl *List[*File] if err = json.Unmarshal(b, fl); err != nil { return nil, err } diff --git a/list.go b/list.go new file mode 100644 index 0000000..5f5eda6 --- /dev/null +++ b/list.go @@ -0,0 +1,10 @@ +package openai + +import ( + "github.com/fabiustech/openai/objects" +) + +type List[T any] struct { + Object objects.Object `json:"object"` + Data []T `json:"data"` +} From 56a1db12a34fc8f443d58e3f92a6e9a2e4d6f203 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Sun, 15 Jan 2023 23:36:53 -0800 Subject: [PATCH 14/30] wip --- fine_tunes.go | 137 ++++++++++++++++++++++++++++++++++++++++++++- objects/objects.go | 3 + routes/routes.go | 5 +- 3 files changed, 143 insertions(+), 2 deletions(-) diff --git a/fine_tunes.go b/fine_tunes.go index 434a887..4b1b527 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -1,3 +1,138 @@ package openai -// TODO: +import ( + "context" + "encoding/json" + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/routes" +) + +// FineTuneRequest ... +type FineTuneRequest struct { + // TrainingFile specifies the ID of an uploaded file that contains training data. See upload file for how to upload + // a file. + // + // https://beta.openai.com/docs/api-reference/files/upload + // + // Your dataset must be formatted as a JSONL file, where each training example is a JSON object with the keys + // "prompt" and "completion". Additionally, you must upload your file with the purpose fine-tune. See the + // fine-tuning guide for more details: + // + // https://beta.openai.com/docs/guides/fine-tuning/creating-training-data + TrainingFile string `json:"training_file"` + // ValidationFile specifies the ID of an uploaded file that contains validation data. 
If you provide this file, the + // data is used to generate validation metrics periodically during fine-tuning. These metrics can be viewed in the + // fine-tuning results file. + // + // https://beta.openai.com/docs/guides/fine-tuning/analyzing-your-fine-tuned-model + // + // Your train and validation data should be mutually exclusive. Your dataset must be formatted as a JSONL file, + // where each validation example is a JSON object with the keys "prompt" and "completion". Additionally, you must + // upload your file with the purpose fine-tune. See the fine-tuning guide for more details: + // + // https://beta.openai.com/docs/guides/fine-tuning/creating-training-data + ValidationFile *string `json:"validation_file,omitempty"` + // Model specifies the name of the base model to fine-tune. You can select one of "ada", "babbage", "curie", + // "davinci", or a fine-tuned model created after 2022-04-21. To learn more about these models, see the Models + // documentation. + // Defaults to "curie". + Model *models.Completion `json:"model,omitempty"` + // NEpochs specifies the number of epochs to train the model for. An epoch refers to one full cycle through + // the training dataset. + // Defaults to 4. + NEpochs *int `json:"n_epochs,omitempty"` + // BatchSize specifies the batch size to use for training. The batch size is the number of training examples used + // to train a single forward and backward pass. By default, the batch size will be dynamically configured to be + // ~0.2% of the number of examples in the training set, capped at 256 - in general, we've found that larger batch + // sizes tend to work better for larger datasets. + // Defaults to null. + BatchSize *int `json:"batch_size,omitempty"` + // LearningRateMultiplier specifies the learning rate multiplier to use for training. The fine-tuning learning rate + // is the original learning rate used for pretraining multiplied by this value. 
By default, the learning rate + // multiplier is the 0.05, 0.1, or 0.2 depending on final batch_size (larger learning rates tend to perform better + // with larger batch sizes). We recommend experimenting with values in the range 0.02 to 0.2 to see what produces + // the best results. + // Defaults to null. + LearningRateMultiplier *int `json:"learning_rate_multiplier,omitempty"` + // PromptLossWeight specifies the weight to use for loss on the prompt tokens. This controls how much the model + // tries to learn to generate the prompt (as compared to the completion which always has a weight of 1.0), and can + // add a stabilizing effect to training when completions are short. If prompts are extremely long (relative to + // completions), it may make sense to reduce this weight so as to avoid over-prioritizing learning the prompt. + // Defaults to 0.01. + PromptLossWeight *int `json:"prompt_loss_weight,omitempty"` + // ComputeClassificationMetrics calculates classification-specific metrics such as accuracy and F-1 score using the + // validation set at the end of every epoch if set to true. These metrics can be viewed in the results file. + // + // https://beta.openai.com/docs/guides/fine-tuning/analyzing-your-fine-tuned-model + // + // In order to compute classification metrics, you must provide a ValidationFile. Additionally, you must specify + // ClassificationNClasses for multiclass classification or ClassificationPositiveClass for binary classification. + ComputeClassificationMetrics bool `json:"compute_classification_metrics,omitempty"` + // ClassificationNClasses specifies the number of classes in a classification task. This parameter is required for + // multiclass classification. + // Defaults to null. + ClassificationNClasses *int `json:"classification_n_classes,omitempty"` + // ClassificationPositiveClass specifies the positive class in binary classification. 
This parameter is needed to + // generate precision, recall, and F1 metrics when doing binary classification. + // Defaults to null. + ClassificationPositiveClass *string `json:"classification_positive_class,omitempty"` + // ClassificationBetas specifies that if provided, we calculate F-beta scores at the specified beta values. The + // F-beta score is a generalization of F-1 score. This is only used for binary classification. With a beta of 1 + // (i.e. the F-1 score), precision and recall are given the same weight. A larger beta score puts more weight on + // recall and less on precision. A smaller beta score puts more weight on precision and less on recall. + // Defaults to null. + ClassificationBetas []float32 `json:"classification_betas,omitempty"` + // Suffix specifies a string of up to 40 characters that will be added to your fine-tuned model name. For example, + // a suffix of "custom-model-name" would produce a model name like + // ada:ft-your-org:custom-model-name-2022-02-15-04-21-04. 
+ Suffix string `json:"suffix,omitempty"` +} + +type FineTuneResponse struct { + ID string `json:"id"` + Object objects.Object `json:"object"` + Model models.Completion `json:"model"` + CreatedAt uint64 `json:"created_at"` + Events []struct { + Object objects.Object `json:"object"` + CreatedAt uint64 `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` + } `json:"events"` + FineTunedModel *string `json:"fine_tuned_model"` + Hyperparams struct { + BatchSize int `json:"batch_size"` + LearningRateMultiplier float64 `json:"learning_rate_multiplier"` + NEpochs int `json:"n_epochs"` + PromptLossWeight float64 `json:"prompt_loss_weight"` + } `json:"hyperparams"` + OrganizationID string `json:"organization_id"` + ResultFiles []string `json:"result_files"` + Status string `json:"status"` + ValidationFiles []string `json:"validation_files"` + TrainingFiles []struct { + ID string `json:"id"` + Object objects.Object `json:"object"` + Bytes int `json:"bytes"` + CreatedAt uint64 `json:"created_at"` + Filename string `json:"filename"` + Purpose string `json:"purpose"` + } `json:"training_files"` + UpdatedAt uint64 `json:"updated_at"` +} + +// CreateFineTune ... +func (c *Client) CreateFineTune(ctx context.Context, ftr *FineTuneRequest) (*FineTuneResponse, error) { + var b, err = c.post(ctx, routes.FineTunes, ftr) + if err != nil { + return nil, err + } + + var f *FineTuneResponse + if err = json.Unmarshal(b, &f); err != nil { + return nil, err + } + + return f, nil +} diff --git a/objects/objects.go b/objects/objects.go index 3dc2855..e4d463e 100644 --- a/objects/objects.go +++ b/objects/objects.go @@ -24,6 +24,7 @@ const ( File // FineTune is a fine-tuned model. FineTune + FineTimeEvent // Engine represents an engine. // Deprecated: use Model instead. 
Engine @@ -61,6 +62,7 @@ var objectToString = map[Object]string{ Embedding: "embedding", File: "file", FineTune: "fine-tune", + FineTimeEvent: "fine-tune-event", Engine: "engine", } @@ -73,5 +75,6 @@ var stringToObject = map[string]Object{ "embedding": Embedding, "file": File, "fine-tune": FineTune, + "fine-tune-event": FineTimeEvent, "engine": Engine, } diff --git a/routes/routes.go b/routes/routes.go index d3945ad..79ee435 100644 --- a/routes/routes.go +++ b/routes/routes.go @@ -21,7 +21,10 @@ const ( // https://beta.openai.com/docs/api-reference/files Files = "files" - // + // FineTunes is the route for the fine-tunes endpoint. + // https://beta.openai.com/docs/api-reference/fine-tunes + FineTunes = "fine-tunes" + imagesBase = "images/" // ImageGenerations is the route for the create images endpoint. From 7d45e478e93942f5cdfe9d2a092f6665f43e3881 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Mon, 16 Jan 2023 08:01:05 -0800 Subject: [PATCH 15/30] implement rest of fine-tune endpoints --- client.go | 12 ++++-- files.go | 4 +- fine_tunes.go | 102 ++++++++++++++++++++++++++++++++++++++++++++------ 3 files changed, 102 insertions(+), 16 deletions(-) diff --git a/client.go b/client.go index dcf8c87..00a7ee1 100644 --- a/client.go +++ b/client.go @@ -192,20 +192,24 @@ func (c *Client) get(ctx context.Context, path string) ([]byte, error) { return io.ReadAll(resp.Body) } -func (c *Client) delete(ctx context.Context, path string) error { +func (c *Client) delete(ctx context.Context, path string) ([]byte, error) { var req, err = http.NewRequestWithContext(ctx, "DELETE", reqURL(path), nil) if err != nil { - return err + return nil, err } var resp *http.Response resp, err = http.DefaultClient.Do(req) if err != nil { - return err + return nil, err } defer resp.Body.Close() - return interpretResponse(resp) + if err = interpretResponse(resp); err != nil { + return nil, err + } + + return io.ReadAll(resp.Body) } func reqURL(route string) string { diff --git a/files.go 
b/files.go index 4d860dc..647922f 100644 --- a/files.go +++ b/files.go @@ -44,7 +44,9 @@ func (c *Client) CreateFile(ctx context.Context, fr *FileRequest) (*File, error) // DeleteFile deletes an existing file. func (c *Client) DeleteFile(ctx context.Context, id string) error { - return c.delete(ctx, path.Join(routes.Files, id)) + var _, err = c.delete(ctx, path.Join(routes.Files, id)) + + return err } // ListFiles Lists the currently available files, diff --git a/fine_tunes.go b/fine_tunes.go index 4b1b527..328cb34 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -6,6 +6,7 @@ import ( "github.com/fabiustech/openai/models" "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" + "path" ) // FineTuneRequest ... @@ -89,18 +90,20 @@ type FineTuneRequest struct { Suffix string `json:"suffix,omitempty"` } +type Event struct { + Object objects.Object `json:"object"` + CreatedAt uint64 `json:"created_at"` + Level string `json:"level"` + Message string `json:"message"` +} + type FineTuneResponse struct { - ID string `json:"id"` - Object objects.Object `json:"object"` - Model models.Completion `json:"model"` - CreatedAt uint64 `json:"created_at"` - Events []struct { - Object objects.Object `json:"object"` - CreatedAt uint64 `json:"created_at"` - Level string `json:"level"` - Message string `json:"message"` - } `json:"events"` - FineTunedModel *string `json:"fine_tuned_model"` + ID string `json:"id"` + Object objects.Object `json:"object"` + Model models.Completion `json:"model"` + CreatedAt uint64 `json:"created_at"` + Events []*Event `json:"events,omitempty"` + FineTunedModel *string `json:"fine_tuned_model"` Hyperparams struct { BatchSize int `json:"batch_size"` LearningRateMultiplier float64 `json:"learning_rate_multiplier"` @@ -122,6 +125,12 @@ type FineTuneResponse struct { UpdatedAt uint64 `json:"updated_at"` } +type FineTuneDeletionResponse struct { + ID string `json:"id"` + Object objects.Object `json:"object"` + Deleted bool `json:"deleted"` 
+} + + // CreateFineTune ... func (c *Client) CreateFineTune(ctx context.Context, ftr *FineTuneRequest) (*FineTuneResponse, error) { var b, err = c.post(ctx, routes.FineTunes, ftr) @@ -136,3 +145,74 @@ func (c *Client) CreateFineTune(ctx context.Context, ftr *FineTuneRequest) (*Fin return f, nil } + +func (c *Client) ListFineTunes(ctx context.Context) (*List[*FineTuneResponse], error) { + var b, err = c.get(ctx, routes.FineTunes) + if err != nil { + return nil, err + } + + var l *List[*FineTuneResponse] + if err = json.Unmarshal(b, &l); err != nil { + return nil, err + } + + return l, nil +} + +func (c *Client) RetrieveFineTune(ctx context.Context, id string) (*FineTuneResponse, error) { + var b, err = c.get(ctx, path.Join(routes.FineTunes, id)) + if err != nil { + return nil, err + } + + var f *FineTuneResponse + if err = json.Unmarshal(b, &f); err != nil { + return nil, err + } + + return f, nil +} + +func (c *Client) CancelFineTune(ctx context.Context, id string) (*FineTuneResponse, error) { + var b, err = c.post(ctx, path.Join(routes.FineTunes, id, "cancel"), nil) + if err != nil { + return nil, err + } + + var f *FineTuneResponse + if err = json.Unmarshal(b, &f); err != nil { + return nil, err + } + + return f, nil +} + +// TODO: Support streaming (maybe different method). 
+func (c *Client) ListFineTuneEvents(ctx context.Context, id string) (*List[*Event], error) { + var b, err = c.get(ctx, path.Join(routes.FineTunes, id, "events")) + if err != nil { + return nil, err + } + + var l *List[*Event] + if err = json.Unmarshal(b, &l); err != nil { + return nil, err + } + + return l, nil +} + +func (c *Client) DeleteFineTune(ctx context.Context, id string) (*FineTuneDeletionResponse, error) { + var b, err = c.delete(ctx, path.Join(routes.FineTunes, id)) + if err != nil { + return nil, err + } + + var f *FineTuneDeletionResponse + if err = json.Unmarshal(b, &f); err != nil { + return nil, err + } + + return f, nil +} From 3533cd55e95428406e93d149c2c96e1404763e4d Mon Sep 17 00:00:00 2001 From: Andy Day Date: Mon, 16 Jan 2023 08:03:39 -0800 Subject: [PATCH 16/30] rename --- fine_tunes.go | 14 +++++++------- models/completions.go | 23 ----------------------- models/fine_tunes.go | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 30 deletions(-) create mode 100644 models/fine_tunes.go diff --git a/fine_tunes.go b/fine_tunes.go index 328cb34..e336dfb 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -38,7 +38,7 @@ type FineTuneRequest struct { // "davinci", or a fine-tuned model created after 2022-04-21. To learn more about these models, see the Models // documentation. // Defaults to "curie". - Model *models.Completion `json:"model,omitempty"` + Model *models.FineTune `json:"model,omitempty"` // NEpochs specifies the number of epochs to train the model for. An epoch refers to one full cycle through // the training dataset. // Defaults to 4. 
@@ -98,12 +98,12 @@ type Event struct { } type FineTuneResponse struct { - ID string `json:"id"` - Object objects.Object `json:"object"` - Model models.Completion `json:"model"` - CreatedAt uint64 `json:"created_at"` - Events []*Event `json:"events,omitempty"` - FineTunedModel *string `json:"fine_tuned_model"` + ID string `json:"id"` + Object objects.Object `json:"object"` + Model models.FineTune `json:"model"` + CreatedAt uint64 `json:"created_at"` + Events []*Event `json:"events,omitempty"` + FineTunedModel *string `json:"fine_tuned_model"` Hyperparams struct { BatchSize int `json:"batch_size"` LearningRateMultiplier float64 `json:"learning_rate_multiplier"` diff --git a/models/completions.go b/models/completions.go index ce40670..23f6d35 100644 --- a/models/completions.go +++ b/models/completions.go @@ -49,21 +49,6 @@ const ( // English instructions. CurieInstructBeta - // Davinci most capable of the older versions of the GPT-3 models - // and is intended to be used with the fine-tuning endpoints. - Davinci - // Curie is very capable, but faster and lower cost than Davinci. It is - // an older version of the GPT-3 models and is intended to be used with - // the fine-tuning endpoints. - Curie - // Babbage is capable of straightforward tasks, very fast, and lower cost. - // It is an older version of the GPT-3 models and is intended to be used - // with the fine-tuning endpoints. - Babbage - // Ada is capable of very simple tasks, usually the fastest model in the - // GPT-3 series, and lowest cost. It is an older version of the GPT-3 - // models and is intended to be used with the fine-tuning endpoints. - Ada // CodeDavinci002 is the most capable Codex model. Particularly good at // translating natural language to code. In addition to completing code, // also supports inserting completions within code. 
@@ -124,10 +109,6 @@ var completionToString = map[Completion]string{ TextDavinci001: "text-davinci-001", DavinciInstructBeta: "davinci-instruct-beta", CurieInstructBeta: "curie-instruct-beta", - Davinci: "davinci", - Curie: "curie", - Ada: "ada", - Babbage: "babbage", CodeDavinci002: "code-davinci-002", CodeCushman001: "code-cushman-001", CodeDavinci001: "code-davinci-001", @@ -144,10 +125,6 @@ var stringToCompletion = map[string]Completion{ "text-davinci-001": TextDavinci001, "davinci-instruct-beta": DavinciInstructBeta, "curie-instruct-beta": CurieInstructBeta, - "davinci": Davinci, - "curie": Curie, - "ada": Ada, - "babbage": Babbage, "code-davinci-002": CodeDavinci002, "code-cushman-001": CodeCushman001, "code-davinci-001": CodeDavinci001, diff --git a/models/fine_tunes.go b/models/fine_tunes.go new file mode 100644 index 0000000..dcca0cc --- /dev/null +++ b/models/fine_tunes.go @@ -0,0 +1,35 @@ +package models + +type FineTune int + +const ( + // Davinci most capable of the older versions of the GPT-3 models + // and is intended to be used with the fine-tuning endpoints. + Davinci FineTune = iota + // Curie is very capable, but faster and lower cost than Davinci. It is + // an older version of the GPT-3 models and is intended to be used with + // the fine-tuning endpoints. + Curie + // Babbage is capable of straightforward tasks, very fast, and lower cost. + // It is an older version of the GPT-3 models and is intended to be used + // with the fine-tuning endpoints. + Babbage + // Ada is capable of very simple tasks, usually the fastest model in the + // GPT-3 series, and lowest cost. It is an older version of the GPT-3 + // models and is intended to be used with the fine-tuning endpoints. 
+ Ada +) + +var fineTuneToString = map[FineTune]string{ + Davinci: "davinci", + Curie: "curie", + Ada: "ada", + Babbage: "babbage", +} + +var stringToFineTune = map[string]FineTune{ + "davinci": Davinci, + "curie": Curie, + "ada": Ada, + "babbage": Babbage, +} From f1d08517270d7436e6183c888e87f918b9e5681f Mon Sep 17 00:00:00 2001 From: Andy Day Date: Mon, 16 Jan 2023 08:05:01 -0800 Subject: [PATCH 17/30] add enum methods --- models/fine_tunes.go | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/models/fine_tunes.go b/models/fine_tunes.go index dcca0cc..4e4caac 100644 --- a/models/fine_tunes.go +++ b/models/fine_tunes.go @@ -3,9 +3,10 @@ package models type FineTune int const ( + UnknownFineTune FineTune = iota // Davinci most capable of the older versions of the GPT-3 models // and is intended to be used with the fine-tuning endpoints. - Davinci FineTune = iota + Davinci // Curie is very capable, but faster and lower cost than Davinci. It is // an older version of the GPT-3 models and is intended to be used with // the fine-tuning endpoints. @@ -20,6 +21,29 @@ const ( Ada ) +// String implements the fmt.Stringer interface. +func (f FineTune) String() string { + return fineTuneToString[f] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (f FineTune) MarshalText() ([]byte, error) { + return []byte(f.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to Unknown. 
+func (f *FineTune) UnmarshalText(b []byte) error { + if val, ok := stringToFineTune[(string(b))]; ok { + *f = val + return nil + } + + *f = UnknownFineTune + + return nil +} + var fineTuneToString = map[FineTune]string{ Davinci: "davinci", Curie: "curie", From e9ce51b491ea7d7970653146cc1991a2e2d239b1 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Mon, 16 Jan 2023 08:10:48 -0800 Subject: [PATCH 18/30] update readme --- README.md | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 6d9b09e..01dc530 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,14 @@ -# go-gpt3 -[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/sashabaranov/go-gpt3) -[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gpt3)](https://goreportcard.com/report/github.com/sashabaranov/go-gpt3) +# openai +[![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/fabiustech/openai) +[![Go Report Card](https://goreportcard.com/badge/github.com/sashabaranov/go-gpt3)](https://goreportcard.com/report/github.com/fabiustech/openai) - -[OpenAI GPT-3](https://beta.openai.com/) API wrapper for Go +Zero dependency Go Client for [OpenAI](https://beta.openai.com/) API endpoints. Built upon the great work done [here](https://github.com/sashabaranov/go-gpt3). 
Installation: ``` go get github.com/fabiustech/openai ``` - Example usage: ```go @@ -21,21 +19,21 @@ import ( "fmt" "github.com/fabiustech/openai" + "github.com/fabiustech/openai/models" ) func main() { - c := openai.NewClient("your token") - ctx := context.Background() - - req := openai.CompletionRequest{ - Model: openai.GPT3Ada, - MaxTokens: 5, + var c = openai.NewClient("your token") + + var resp, err = c.CreateCompletion(context.Background(), &openai.CompletionRequest{ + Model: models.TextDavinci003, + MaxTokens: 100, Prompt: "Lorem ipsum", - } - resp, err := c.CreateCompletion(ctx, req) + }) if err != nil { return } + fmt.Println(resp.Choices[0].Text) } ``` From afcf646aa9917dc1ac35e08716d8d543e7a9aa82 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Tue, 17 Jan 2023 10:05:48 -0800 Subject: [PATCH 19/30] good --- client.go | 2 +- completions.go | 6 +----- error.go | 2 +- files.go | 28 +++++++++++++++------------- images.go | 36 ++++++++++++++++++++++++++++-------- common.go => usage.go | 0 6 files changed, 46 insertions(+), 28 deletions(-) rename common.go => usage.go (100%) diff --git a/client.go b/client.go index 00a7ee1..c7c8230 100644 --- a/client.go +++ b/client.go @@ -95,7 +95,7 @@ func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) } var fw io.Writer - fw, err = w.CreateFormFile("file", fr.FileName) + fw, err = w.CreateFormFile("file", fr.File) if err != nil { return nil, err } diff --git a/completions.go b/completions.go index 16fd27a..fc9dfa7 100644 --- a/completions.go +++ b/completions.go @@ -116,11 +116,7 @@ type CompletionResponse struct { Usage *Usage `json:"usage"` } -// CreateCompletion — API call to create a completion. This is the main endpoint of the API. Returns new text as well -// as, if requested, the probabilities over each alternative token at each position. 
-// -// If using a fine-tuned model, simply provide the model's ID in the CompletionRequest object, -// and the server will use the model's parameters to generate the completion. +// CreateCompletion ... func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest) (*CompletionResponse, error) { var b, err = c.post(ctx, routes.Completions, cr) if err != nil { diff --git a/error.go b/error.go index 7d34401..0bbadf9 100644 --- a/error.go +++ b/error.go @@ -10,7 +10,7 @@ type ErrorResponse struct { } type Error struct { - Code int `json:"code,omitempty"` + Code int `json:"code"` Message string `json:"message"` Param *string `json:"param,omitempty"` Type string `json:"type"` diff --git a/files.go b/files.go index 647922f..357ccdd 100644 --- a/files.go +++ b/files.go @@ -10,24 +10,28 @@ import ( // FileRequest ... type FileRequest struct { - FileName string `json:"file"` + // File is the JSON Lines file to be uploaded. If the purpose is set to "fine-tune", each line is a JSON record + // with "prompt" and "completion" fields representing your training examples: + // https://beta.openai.com/docs/guides/fine-tuning/prepare-training-data. + File string `json:"file"` FilePath string `json:"-"` - Purpose string `json:"purpose"` + // The intended purpose of the uploaded documents. Use "fine-tune" for Fine-tuning. + // This allows OpenAI to validate the format of the uploaded file. + Purpose string `json:"purpose"` } // File struct represents an OpenAPI file. type File struct { - Bytes int `json:"bytes"` - CreatedAt int `json:"created_at"` ID string `json:"id"` - FileName string `json:"filename"` Object objects.Object `json:"object"` - Owner string `json:"owner"` + Bytes int `json:"bytes"` + CreatedAt int `json:"created_at"` + Filename string `json:"filename"` Purpose string `json:"purpose"` } -// CreateFile uploads a jsonl file to GPT3 -// FilePath can be either a local file path or a URL. +// TODO: FileRequest should take a file.File. +// CreateFile ... 
func (c *Client) CreateFile(ctx context.Context, fr *FileRequest) (*File, error) { var b, err = c.postFile(ctx, fr) if err != nil { @@ -42,15 +46,14 @@ func (c *Client) CreateFile(ctx context.Context, fr *FileRequest) (*File, error) return f, nil } -// DeleteFile deletes an existing file. +// DeleteFile ... func (c *Client) DeleteFile(ctx context.Context, id string) error { var _, err = c.delete(ctx, path.Join(routes.Files, id)) return err } -// ListFiles Lists the currently available files, -// and provides basic information about each file such as the file name and purpose. +// ListFiles ... func (c *Client) ListFiles(ctx context.Context) (*List[*File], error) { var b, err = c.get(ctx, routes.Files) if err != nil { @@ -65,8 +68,7 @@ func (c *Client) ListFiles(ctx context.Context) (*List[*File], error) { return fl, nil } -// GetFile Retrieves a file instance, providing basic information about the file -// such as the file name and purpose. +// GetFile ... func (c *Client) GetFile(ctx context.Context, id string) (*File, error) { var b, err = c.get(ctx, path.Join(routes.Files, id)) if err != nil { diff --git a/images.go b/images.go index 6c1f223..7363cbd 100644 --- a/images.go +++ b/images.go @@ -7,29 +7,49 @@ import ( "github.com/fabiustech/openai/routes" ) -// CreateImageRequest represents the request structure for the image API. +// CreateImageRequest contains all relevant fields for requests to the images/generations endpoint. type CreateImageRequest struct { + // Prompt is a text description of the desired image(s). The maximum length is 1000 characters. Prompt string `json:"prompt"` *ImageRequestFields } +// EditImageRequest contains all relevant fields for requests to the images/edits endpoint. type EditImageRequest struct { - Image string `json:"image"` - Mask *string `json:"mask,omitempty"` - Prompt string `json:"prompt"` + // Image is the image to edit. Must be a valid PNG file, less than 4MB, and square. 
If Mask is not provided, image + // must have transparency, which will be used as the mask. + Image string `json:"image"` + // Mask is an additional image whose fully transparent areas (e.g. where alpha is zero) indicate where image should + // be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as Image. + Mask string `json:"mask,omitempty"` + // Prompt is a text description of the desired image(s). The maximum length is 1000 characters. + Prompt string `json:"prompt"` *ImageRequestFields } +// VariationImageRequest contains all relevant fields for requests to the images/variations endpoint. type VariationImageRequest struct { + // Image is the image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. Image string `json:"image"` *ImageRequestFields } +// ImageRequestFields contains the common fields for all images endpoints. type ImageRequestFields struct { - N *int `json:"n,omitempty"` - Size *images.Size `json:"size,omitempty"` - ResponseFormat *images.Format `json:"response_format,omitempty"` - User *string `json:"user,omitempty"` + // N specifies the number of images to generate. Must be between 1 and 10. + // Defaults to 1. + N int `json:"n,omitempty"` + // Size specifies the size of the generated images. Must be one of images.Size256x256, images.Size512x512, or + // images.Size1024x1024. + // Defaults to images.Size1024x1024. + Size images.Size `json:"size,omitempty"` + // ResponseFormat specifies the format in which the generated images are returned. Must be one of images.FormatURL + // or images.FormatB64JSON. + // Defaults to images.FormatURL. + ResponseFormat images.Format `json:"response_format,omitempty"` + // User specifies a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse: + // https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids. 
+ User string `json:"user,omitempty"` } // ImageResponse represents a response structure for image API. diff --git a/common.go b/usage.go similarity index 100% rename from common.go rename to usage.go From dffd7b964589d4b93fade975305f2510bcd80eb6 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Tue, 17 Jan 2023 10:16:01 -0800 Subject: [PATCH 20/30] . --- client.go | 3 ++- client_test.go | 23 +++++++++++---------- completions.go | 1 + edits.go | 1 + embeddings.go | 1 + engines.go | 3 ++- files.go | 3 ++- fine_tunes.go | 3 ++- images.go | 1 + models/moderations.go | 48 +++++++++++++++++++++++++++++++++++++++++++ moderation.go | 10 +++++++-- 11 files changed, 80 insertions(+), 17 deletions(-) create mode 100644 models/moderations.go diff --git a/client.go b/client.go index c7c8230..8b79ff8 100644 --- a/client.go +++ b/client.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/fabiustech/openai/routes" "io" "mime/multipart" "net/http" @@ -13,6 +12,8 @@ import ( "os" "path" "strings" + + "github.com/fabiustech/openai/routes" ) const ( diff --git a/client_test.go b/client_test.go index d89a6c1..813fcec 100644 --- a/client_test.go +++ b/client_test.go @@ -5,11 +5,6 @@ import ( "context" "encoding/json" "fmt" - "github.com/fabiustech/openai" - "github.com/fabiustech/openai/images" - "github.com/fabiustech/openai/models" - "github.com/fabiustech/openai/objects" - "github.com/fabiustech/openai/params" "io" "log" "net/http" @@ -19,6 +14,12 @@ import ( "strings" "testing" "time" + + "github.com/fabiustech/openai" + "github.com/fabiustech/openai/images" + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/objects" + "github.com/fabiustech/openai/params" ) const ( @@ -49,8 +50,8 @@ func TestAPI(t *testing.T) { t.Fatalf("ListFiles error: %v", err) } - if len(fileRes.Files) > 0 { - _, err = c.GetFile(ctx, fileRes.Files[0].ID) + if len(fileRes.Data) > 0 { + _, err = c.GetFile(ctx, fileRes.Data[0].ID) if err != nil { t.Fatalf("GetFile error: 
%v", err) } @@ -83,7 +84,7 @@ func TestCompletions(t *testing.T) { req := &openai.CompletionRequest{ MaxTokens: params.Optional(5), - Model: models.Ada, + Model: models.TextDavinci003, } req.Prompt = "Lorem ipsum" _, err = client.CreateCompletion(ctx, req) @@ -282,12 +283,12 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { res := &openai.ImageResponse{ Created: uint64(time.Now().Unix()), } - for i := 0; i < *imageReq.N; i++ { + for i := 0; i < imageReq.N; i++ { var imageData = &openai.ImageData{} switch imageReq.ResponseFormat { - case params.Optional(images.FormatURL), nil: + case images.FormatURL: imageData.URL = params.Optional("https://example.com/image.png") - case params.Optional(images.FormatB64JSON): + case images.FormatB64JSON: // This decodes to "{}" in base64. imageData.B64JSON = params.Optional("e30K") default: diff --git a/completions.go b/completions.go index fc9dfa7..1ae7c5d 100644 --- a/completions.go +++ b/completions.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/models" "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" diff --git a/edits.go b/edits.go index 54a2d8d..cad00dc 100644 --- a/edits.go +++ b/edits.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/models" "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" diff --git a/embeddings.go b/embeddings.go index 29fc24d..b10170e 100644 --- a/embeddings.go +++ b/embeddings.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/models" "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" diff --git a/engines.go b/engines.go index 511e5e0..70848c5 100644 --- a/engines.go +++ b/engines.go @@ -3,8 +3,9 @@ package openai import ( "context" "encoding/json" - "github.com/fabiustech/openai/routes" "path" + + "github.com/fabiustech/openai/routes" ) // Engine 
struct represents engine from OpenAPI API. diff --git a/files.go b/files.go index 357ccdd..4b35f9f 100644 --- a/files.go +++ b/files.go @@ -3,9 +3,10 @@ package openai import ( "context" "encoding/json" + "path" + "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" - "path" ) // FileRequest ... diff --git a/fine_tunes.go b/fine_tunes.go index e336dfb..637d769 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -3,10 +3,11 @@ package openai import ( "context" "encoding/json" + "path" + "github.com/fabiustech/openai/models" "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/routes" - "path" ) // FineTuneRequest ... diff --git a/images.go b/images.go index 7363cbd..603b816 100644 --- a/images.go +++ b/images.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "github.com/fabiustech/openai/images" "github.com/fabiustech/openai/routes" ) diff --git a/models/moderations.go b/models/moderations.go new file mode 100644 index 0000000..fa65c7d --- /dev/null +++ b/models/moderations.go @@ -0,0 +1,48 @@ +package models + +// Moderation represents all models available for use with the Moderations endpoint. +type Moderation int + +const ( + // UnknownModeration represents and invalid Moderation model. + UnknownModeration Moderation = iota + // TextModerationStable ... + TextModerationStable + // TextModerationLatest ... + TextModerationLatest +) + +// String implements the fmt.Stringer interface. +func (m Moderation) String() string { + return moderationToString[m] +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (m Moderation) MarshalText() ([]byte, error) { + return []byte(m.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface. +// On unrecognized value, it sets |e| to Unknown. 
+func (m *Moderation) UnmarshalText(b []byte) error { + if val, ok := stringToModeration[(string(b))]; ok { + *m = val + return nil + } + + *m = UnknownModeration + + return nil +} + +var moderationToString = map[Moderation]string{ + // TextDavinciEdit001 can be used to edit text, rather than just completing it. + TextModerationStable: "text-moderation-stable", + // CodeDavinciEdit001 can be used to edit code, rather than just completing it. + TextModerationLatest: "text-moderation-latest", +} + +var stringToModeration = map[string]Moderation{ + "text-moderation-stable": TextModerationStable, + "text-moderation-latest": TextModerationLatest, +} diff --git a/moderation.go b/moderation.go index d5846c0..91568bb 100644 --- a/moderation.go +++ b/moderation.go @@ -3,13 +3,19 @@ package openai import ( "context" "encoding/json" + + "github.com/fabiustech/openai/models" + "github.com/fabiustech/openai/routes" ) // ModerationRequest represents a request structure for moderation API. type ModerationRequest struct { - Input string `json:"input,omitempty"` - Model *string `json:"model,omitempty"` + // Input is the input text to classify. + Input string `json:"input,omitempty"` + // Model specifies the model to use for moderation. + // Defaults to models.TextModerationLatest. + Model models.Moderation `json:"model,omitempty"` } // Result represents one of possible moderation results. From 278ed8a4dc233637c356f286d0092e49804c9cb4 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 13:01:41 -0800 Subject: [PATCH 21/30] . 
--- client.go | 96 ++++++++++------------------------ client_test.go | 139 +++++++++++++++++++++++++------------------------ completions.go | 8 +-- edits.go | 8 +-- embeddings.go | 2 +- engines.go | 2 +- error.go | 6 ++- files.go | 28 +++++++--- fine_tunes.go | 7 ++- list.go | 5 +- moderation.go | 2 +- 11 files changed, 143 insertions(+), 160 deletions(-) diff --git a/client.go b/client.go index 8b79ff8..c9c8116 100644 --- a/client.go +++ b/client.go @@ -9,9 +9,7 @@ import ( "mime/multipart" "net/http" "net/url" - "os" "path" - "strings" "github.com/fabiustech/openai/routes" ) @@ -26,20 +24,27 @@ const ( type Client struct { token string orgID *string + + // scheme and host are only used for testing. + scheme, host string } // NewClient creates new OpenAI API client. func NewClient(token string) *Client { return &Client{ - token: token, + token: token, + scheme: scheme, + host: host, } } // NewClientWithOrg creates new OpenAI API client for specified Organization ID. func NewClientWithOrg(token, org string) *Client { return &Client{ - token: token, - orgID: &org, + token: token, + orgID: &org, + scheme: scheme, + host: host, } } @@ -50,17 +55,12 @@ func (c *Client) post(ctx context.Context, path string, payload any) ([]byte, er } var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", reqURL(path), bytes.NewBuffer(b)) + req, err = http.NewRequestWithContext(ctx, "POST", c.reqURL(path), bytes.NewBuffer(b)) if err != nil { return nil, err } - switch payload.(type) { - case FileRequest: - req.Header.Set("Content-Type", "") // TODO - default: - req.Header.Set("Content-Type", "application/json; charset=utf-8") - } + req.Header.Set("Content-Type", "application/json; charset=utf-8") if c.orgID != nil { req.Header.Set("OpenAI-Organization", *c.orgID) @@ -80,48 +80,39 @@ func (c *Client) post(ctx context.Context, path string, payload any) ([]byte, er return io.ReadAll(resp.Body) } -// TODO: improve this. 
func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) { var b bytes.Buffer var w = multipart.NewWriter(&b) - var pw, err = w.CreateFormField("purpose") - if err != nil { - return nil, err - } - - _, err = io.Copy(pw, strings.NewReader(fr.Purpose)) - if err != nil { + if err := w.WriteField("purposes", fr.Purpose); err != nil { return nil, err } - var fw io.Writer - fw, err = w.CreateFormFile("file", fr.File) + var fw, err = w.CreateFormFile("file", fr.File.Name()) if err != nil { return nil, err } - var file io.ReadCloser - file, err = readFile(fr.FilePath) - if err != nil { + if _, err = io.Copy(fw, fr.File); err != nil { return nil, err } - defer file.Close() - if _, err = io.Copy(fw, file); err != nil { + if err = w.Close(); err != nil { return nil, err } - w.Close() - var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", reqURL(routes.Files), &b) + req, err = http.NewRequestWithContext(ctx, "POST", c.reqURL(routes.Files), &b) if err != nil { return nil, err } req.Header.Set("Content-Type", w.FormDataContentType()) + if c.orgID != nil { + req.Header.Set("OpenAI-Organization", *c.orgID) + } + var resp *http.Response resp, err = http.DefaultClient.Do(req) if err != nil { @@ -136,41 +127,8 @@ func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) return io.ReadAll(resp.Body) } -func readFile(path string) (io.ReadCloser, error) { - if !isURL(path) { - return os.Open(path) - } - - var resp, err = http.Get(path) - if err != nil { - return nil, err - } - - // Check server response. - if resp.StatusCode != http.StatusOK { - _ = resp.Body.Close() - return nil, fmt.Errorf("error, status code: %d, message: failed to fetch file", resp.StatusCode) - } - - return resp.Body, nil -} - -// isUrl is a helper function that determines whether the given FilePath -// is a remote URL or a local file path. 
-func isURL(path string) bool { - if _, err := url.ParseRequestURI(path); err != nil { - return false - } - - if u, err := url.Parse(path); err != nil || u.Scheme == "" || u.Host == "" { - return false - } - - return true -} - func (c *Client) get(ctx context.Context, path string) ([]byte, error) { - var req, err = http.NewRequestWithContext(ctx, "POST", reqURL(path), nil) + var req, err = http.NewRequestWithContext(ctx, "POST", c.reqURL(path), nil) if err != nil { return nil, err } @@ -194,7 +152,7 @@ func (c *Client) get(ctx context.Context, path string) ([]byte, error) { } func (c *Client) delete(ctx context.Context, path string) ([]byte, error) { - var req, err = http.NewRequestWithContext(ctx, "DELETE", reqURL(path), nil) + var req, err = http.NewRequestWithContext(ctx, "DELETE", c.reqURL(path), nil) if err != nil { return nil, err } @@ -213,10 +171,10 @@ func (c *Client) delete(ctx context.Context, path string) ([]byte, error) { return io.ReadAll(resp.Body) } -func reqURL(route string) string { +func (c *Client) reqURL(route string) string { var u = &url.URL{ - Scheme: scheme, - Host: host, + Scheme: c.scheme, + Host: c.host, Path: path.Join(basePath, route), } return u.String() @@ -228,7 +186,7 @@ func interpretResponse(resp *http.Response) error { if err != nil { return fmt.Errorf("error, status code: %d", resp.StatusCode) } - var er *ErrorResponse + var er *errorResponse if err = json.Unmarshal(b, er); err != nil || er.Error == nil { return fmt.Errorf("error, status code: %d, msg: %s", resp.StatusCode, string(b)) } diff --git a/client_test.go b/client_test.go index 813fcec..e47d550 100644 --- a/client_test.go +++ b/client_test.go @@ -1,4 +1,4 @@ -package openai_test +package openai import ( "bytes" @@ -15,27 +15,31 @@ import ( "testing" "time" - "github.com/fabiustech/openai" "github.com/fabiustech/openai/images" "github.com/fabiustech/openai/models" "github.com/fabiustech/openai/objects" "github.com/fabiustech/openai/params" ) +/* +This test suite has 
been ported from the original repo: + +TODO: Cover all endpoints. +*/ + const ( - testAPIToken = "this-is-my-secure-token-do-not-steal!!" + testToken = "this-is-my-secure-token-do-not-steal!!" ) func TestAPI(t *testing.T) { - apiToken := os.Getenv("OPENAI_TOKEN") - if apiToken == "" { + var token, ok = os.LookupEnv("OPENAI_TOKEN") + if !ok { t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.") } - var err error - c := openai.NewClient(apiToken) - ctx := context.Background() - _, err = c.ListEngines(ctx) + var c = NewClient(token) + var ctx = context.Background() + var _, err = c.ListEngines(ctx) if err != nil { t.Fatalf("ListEngines error: %v", err) } @@ -45,49 +49,53 @@ func TestAPI(t *testing.T) { t.Fatalf("GetEngine error: %v", err) } - fileRes, err := c.ListFiles(ctx) + var fl *List[*File] + fl, err = c.ListFiles(ctx) if err != nil { t.Fatalf("ListFiles error: %v", err) } - if len(fileRes.Data) > 0 { - _, err = c.GetFile(ctx, fileRes.Data[0].ID) + if len(fl.Data) > 0 { + _, err = c.GetFile(ctx, fl.Data[0].ID) if err != nil { t.Fatalf("GetFile error: %v", err) } - } // else skip + } - embeddingReq := &openai.EmbeddingRequest{ + _, err = c.CreateEmbeddings(ctx, &EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", }, - Model: models.AdaSearchQuery, - } - _, err = c.CreateEmbeddings(ctx, embeddingReq) + Model: models.AdaEmbeddingV2, + }) if err != nil { t.Fatalf("Embedding error: %v", err) } } +func newTestClient(host string) *Client { + return &Client{ + token: testToken, + host: host, + scheme: "http", + } +} + // TestCompletions Tests the completions endpoint of the API using the mocked server. 
func TestCompletions(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() + var ts = OpenAITestServer() ts.Start() defer ts.Close() - client := openai.NewClient(testAPIToken) + var client = newTestClient(ts.URL) ctx := context.Background() - // client.BaseURL = ts.URL + "/v1" - req := &openai.CompletionRequest{ - MaxTokens: params.Optional(5), + var _, err = client.CreateCompletion(ctx, &CompletionRequest{ + Prompt: "Lorem ipsum", Model: models.TextDavinci003, - } - req.Prompt = "Lorem ipsum" - _, err = client.CreateCompletion(ctx, req) + MaxTokens: 5, + }) if err != nil { t.Fatalf("CreateCompletion error: %v", err) } @@ -95,32 +103,27 @@ func TestCompletions(t *testing.T) { // TestEdits Tests the edits endpoint of the API using the mocked server. func TestEdits(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() + var ts = OpenAITestServer() ts.Start() defer ts.Close() - client := openai.NewClient(testAPIToken) + var client = newTestClient(ts.URL) ctx := context.Background() - // client.BaseURL = ts.URL + "/v1" - // create an edit request - - editReq := &openai.EditsRequest{ + var n = 3 + var resp, err = client.CreateEdit(ctx, &EditsRequest{ Model: models.TextDavinciEdit001, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + " ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip" + " ex ea commodo consequat. 
Duis aute irure dolor in reprehe", Instruction: "test instruction", - N: params.Optional(3), - } - response, err := client.Edits(ctx, editReq) + N: n, + }) if err != nil { t.Fatalf("Edits error: %v", err) } - if len(response.Choices) != *editReq.N { + if len(resp.Choices) != n { t.Fatalf("edits does not properly return the correct number of choices") } } @@ -146,7 +149,7 @@ func TestEmbedding(t *testing.T) { models.AdaEmbeddingV2, } for _, model := range embeddedModels { - embeddingReq := &openai.EmbeddingRequest{ + embeddingReq := &EmbeddingRequest{ Input: []string{ "The food was delicious and the waiter", "Other examples of embedding request", @@ -166,8 +169,8 @@ func TestEmbedding(t *testing.T) { } // getEditBody Returns the body of the request to create an edit. -func getEditBody(r *http.Request) (*openai.EditsRequest, error) { - edit := &openai.EditsRequest{} +func getEditBody(r *http.Request) (*EditsRequest, error) { + edit := &EditsRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { @@ -189,29 +192,29 @@ func handleEditEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var editReq *openai.EditsRequest + var editReq *EditsRequest editReq, err = getEditBody(r) if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } // create a response - res := &openai.EditsResponse{ + res := &EditsResponse{ Object: objects.Edit, Created: uint64(time.Now().Unix()), } // edit and calculate token usage editString := "edited by mocked OpenAI server :)" - inputTokens := numTokens(editReq.Input+editReq.Instruction) * *editReq.N - completionTokens := int(float32(len(editString))/4) * *editReq.N - for i := 0; i < *editReq.N; i++ { + inputTokens := numTokens(editReq.Input+editReq.Instruction) * editReq.N + completionTokens := int(float32(len(editString))/4) * editReq.N + for i := 0; i < editReq.N; i++ { // instruction 
will be hidden and only seen by OpenAI - res.Choices = append(res.Choices, &openai.EditsChoice{ + res.Choices = append(res.Choices, &EditsChoice{ Text: editReq.Input + editString, Index: i, }) } - res.Usage = &openai.Usage{ + res.Usage = &Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + completionTokens, @@ -229,12 +232,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq *openai.CompletionRequest + var completionReq *CompletionRequest if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := &openai.CompletionResponse{ + res := &CompletionResponse{ ID: strconv.Itoa(int(time.Now().Unix())), Object: objects.TextCompletion, Created: uint64(time.Now().Unix()), @@ -244,20 +247,20 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { Model: completionReq.Model, } // create completions - for i := 0; i < *completionReq.N; i++ { + for i := 0; i < completionReq.N; i++ { // generate a random string of length completionReq.Length - completionStr := strings.Repeat("a", *completionReq.MaxTokens) + completionStr := strings.Repeat("a", completionReq.MaxTokens) if completionReq.Echo { completionStr = completionReq.Prompt + completionStr } - res.Choices = append(res.Choices, &openai.CompletionChoice{ + res.Choices = append(res.Choices, &CompletionChoice{ Text: completionStr, Index: i, }) } - inputTokens := numTokens(completionReq.Prompt) * *completionReq.N - completionTokens := *completionReq.MaxTokens * *completionReq.N - res.Usage = &openai.Usage{ + inputTokens := numTokens(completionReq.Prompt) * completionReq.N + completionTokens := completionReq.MaxTokens * completionReq.N + res.Usage = &Usage{ PromptTokens: inputTokens, CompletionTokens: completionTokens, TotalTokens: inputTokens + 
completionTokens, @@ -275,16 +278,16 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq *openai.CreateImageRequest + var imageReq *CreateImageRequest if imageReq, err = getImageBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := &openai.ImageResponse{ + res := &ImageResponse{ Created: uint64(time.Now().Unix()), } for i := 0; i < imageReq.N; i++ { - var imageData = &openai.ImageData{} + var imageData = &ImageData{} switch imageReq.ResponseFormat { case images.FormatURL: imageData.URL = params.Optional("https://example.com/image.png") @@ -302,8 +305,8 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { } // getCompletionBody Returns the body of the request to create a completion. -func getCompletionBody(r *http.Request) (*openai.CompletionRequest, error) { - var completion = &openai.CompletionRequest{} +func getCompletionBody(r *http.Request) (*CompletionRequest, error) { + var completion = &CompletionRequest{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { @@ -317,8 +320,8 @@ func getCompletionBody(r *http.Request) (*openai.CompletionRequest, error) { } // getImageBody Returns the body of the request to create a image. -func getImageBody(r *http.Request) (*openai.CreateImageRequest, error) { - var image = &openai.CreateImageRequest{} +func getImageBody(r *http.Request) (*CreateImageRequest, error) { + var image = &CreateImageRequest{} // read the request body var reqBody, err = io.ReadAll(r.Body) if err != nil { @@ -333,7 +336,7 @@ func getImageBody(r *http.Request) (*openai.CreateImageRequest, error) { // numTokens Returns the number of GPT-3 encoded tokens in the given text. 
// This function approximates based on the rule of thumb stated by OpenAI: -// https://beta.openai.com/tokenizer +// https://beta.com/tokenizer // // TODO: implement an actual tokenizer for GPT-3 and Codex (once available) func numTokens(s string) int { @@ -347,11 +350,11 @@ func TestImages(t *testing.T) { ts.Start() defer ts.Close() - client := openai.NewClient(testAPIToken) + client := NewClient(testToken) ctx := context.Background() // client.BaseURL = ts.URL + "/v1" - req := &openai.CreateImageRequest{} + req := &CreateImageRequest{} req.Prompt = "Lorem ipsum" _, err = client.CreateImage(ctx, req) if err != nil { @@ -365,7 +368,7 @@ func OpenAITestServer() *httptest.Server { log.Printf("received request at path %q\n", r.URL.Path) // check auth - if r.Header.Get("Authorization") != "Bearer "+testAPIToken { + if r.Header.Get("Authorization") != "Bearer "+testToken { w.WriteHeader(http.StatusUnauthorized) return } diff --git a/completions.go b/completions.go index 1ae7c5d..670799d 100644 --- a/completions.go +++ b/completions.go @@ -9,7 +9,7 @@ import ( "github.com/fabiustech/openai/routes" ) -// CompletionRequest represents a request structure for Completion API. +// CompletionRequest contains all relevant fields for requests to the completions endpoint. type CompletionRequest struct { // Model specifies the ID of the model to use. // See more here: https://beta.openai.com/docs/models/overview @@ -21,12 +21,12 @@ type CompletionRequest struct { Prompt string `json:"prompt,omitempty"` // Suffix specifies the suffix that comes after a completion of inserted text. // Defaults to null. - Suffix *string `json:"suffix,omitempty"` + Suffix string `json:"suffix,omitempty"` // MaxTokens specifies the maximum number of tokens to generate in the completion. The token count of your prompt plus // max_tokens cannot exceed the model's context length. Most models have a context length of 2048 tokens (except // for the newest models, which support 4096). // Defaults to 16. 
- MaxTokens *int `json:"max_tokens,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative // applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally recommends altering // this or top_p but not both. @@ -42,7 +42,7 @@ type CompletionRequest struct { // Note: Because this parameter generates many completions, it can quickly consume your token quota. Use carefully // and ensure that you have reasonable settings for max_tokens and stop. // Defaults to 1. - N *int `json:"n,omitempty"` + N int `json:"n,omitempty"` // Steam specifies Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent // events as they become available, with the stream terminated by a data: [DONE] message. // Defaults to false. diff --git a/edits.go b/edits.go index cad00dc..7749248 100644 --- a/edits.go +++ b/edits.go @@ -9,7 +9,7 @@ import ( "github.com/fabiustech/openai/routes" ) -// EditsRequest represents a request structure for Edits API. +// EditsRequest contains all relevant fields for requests to the edits endpoint. type EditsRequest struct { Model models.Edit `json:"model"` // Input is the input text to use as a starting point for the edit. @@ -19,7 +19,7 @@ type EditsRequest struct { Instruction string `json:"instruction,omitempty"` // N specifies how many edits to generate for the input and instruction. // Defaults to 1. - N *int `json:"n,omitempty"` + N int `json:"n,omitempty"` // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative // applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally recommends altering // this or top_p but not both. @@ -47,8 +47,8 @@ type EditsResponse struct { Choices []*EditsChoice `json:"choices"` } -// Edits ... 
-func (c *Client) Edits(ctx context.Context, er *EditsRequest) (*EditsResponse, error) { +// CreateEdit ... +func (c *Client) CreateEdit(ctx context.Context, er *EditsRequest) (*EditsResponse, error) { var b, err = c.post(ctx, routes.Edits, er) var resp *EditsResponse diff --git a/embeddings.go b/embeddings.go index b10170e..c7ee87a 100644 --- a/embeddings.go +++ b/embeddings.go @@ -28,7 +28,7 @@ type EmbeddingResponse struct { Usage *Usage } -// EmbeddingRequest is the input to a Create embeddings request. +// EmbeddingRequest contains all relevant fields for requests to the embeddings endpoint. type EmbeddingRequest struct { // Input is a slice of strings for which you want to generate an Embedding vector. // Each input must not exceed 2048 tokens in length. diff --git a/engines.go b/engines.go index 70848c5..2f29434 100644 --- a/engines.go +++ b/engines.go @@ -8,7 +8,7 @@ import ( "github.com/fabiustech/openai/routes" ) -// Engine struct represents engine from OpenAPI API. +// Engine contains all relevant fields for requests to the engines endpoint. type Engine struct { ID string `json:"id"` Object string `json:"object"` diff --git a/error.go b/error.go index 0bbadf9..86e77e7 100644 --- a/error.go +++ b/error.go @@ -5,10 +5,12 @@ import ( "net/http" ) -type ErrorResponse struct { +// errorResponse wraps the returned error. +type errorResponse struct { Error *Error `json:"error,omitempty"` } +// Error represents an error response from the API. type Error struct { Code int `json:"code"` Message string `json:"message"` @@ -16,10 +18,12 @@ type Error struct { Type string `json:"type"` } +// Error implements the error interface. func (e *Error) Error() string { return fmt.Sprintf("Code: %v, Message: %s, Type: %s, Param: %v", e.Code, e.Message, e.Type, e.Param) } +// Retryable returns true if the error is retryable. 
func (e *Error) Retryable() bool { if e.Code >= http.StatusInternalServerError { return true diff --git a/files.go b/files.go index 4b35f9f..9609d07 100644 --- a/files.go +++ b/files.go @@ -3,6 +3,7 @@ package openai import ( "context" "encoding/json" + "os" "path" "github.com/fabiustech/openai/objects" @@ -14,14 +15,26 @@ type FileRequest struct { // File is the JSON Lines file to be uploaded. If the purpose is set to "fine-tune", each line is a JSON record // with "prompt" and "completion" fields representing your training examples: // https://beta.openai.com/docs/guides/fine-tuning/prepare-training-data. - File string `json:"file"` - FilePath string `json:"-"` - // The intended purpose of the uploaded documents. Use "fine-tune" for Fine-tuning. + File *os.File + // Purpose is the intended purpose of the uploaded documents. Use "fine-tune" for Fine-tuning. // This allows OpenAI to validate the format of the uploaded file. - Purpose string `json:"purpose"` + Purpose string } -// File struct represents an OpenAPI file. +// NewFineTuneFileRequest returns a |*FileRequest| with File opened from |path| and Purpose set to "fine-tuned". +func NewFineTuneFileRequest(path string) (*FileRequest, error) { + var f, err = os.Open(path) + if err != nil { + return nil, err + } + + return &FileRequest{ + File: f, + Purpose: "fine-tune", + }, nil +} + +// File represents an OpenAPI file. type File struct { ID string `json:"id"` Object objects.Object `json:"object"` @@ -31,9 +44,8 @@ type File struct { Purpose string `json:"purpose"` } -// TODO: FileRequest should take a file.File. -// CreateFile ... -func (c *Client) CreateFile(ctx context.Context, fr *FileRequest) (*File, error) { +// UploadFile ... 
+func (c *Client) UploadFile(ctx context.Context, fr *FileRequest) (*File, error) { var b, err = c.postFile(ctx, fr) if err != nil { return nil, err diff --git a/fine_tunes.go b/fine_tunes.go index 637d769..bf26f7b 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -10,7 +10,7 @@ import ( "github.com/fabiustech/openai/routes" ) -// FineTuneRequest ... +// FineTuneRequest contains all relevant fields for requests to the fine-tunes endpoints. type FineTuneRequest struct { // TrainingFile specifies the ID of an uploaded file that contains training data. See upload file for how to upload // a file. @@ -91,6 +91,7 @@ type FineTuneRequest struct { Suffix string `json:"suffix,omitempty"` } +// Event represents an event related to a fine-tune request. type Event struct { Object objects.Object `json:"object"` CreatedAt uint64 `json:"created_at"` @@ -98,6 +99,7 @@ type Event struct { Message string `json:"message"` } +// FineTuneResponse is the response from fine-tunes endpoints. type FineTuneResponse struct { ID string `json:"id"` Object objects.Object `json:"object"` @@ -126,6 +128,7 @@ type FineTuneResponse struct { UpdatedAt uint64 `json:"updated_at"` } +// FineTuneDeletionResponse is the response from the fine-tunes/delete endpoint. type FineTuneDeletionResponse struct { ID string `json:"id"` Object objects.Object `json:"object"` @@ -189,7 +192,7 @@ func (c *Client) CancelFineTune(ctx context.Context, id string) (*FineTuneRespon return f, nil } -// TODO: Support streaming (maybe different method). +// TODO: Support streaming (in a different method). func (c *Client) ListFineTuneEvents(ctx context.Context, id string) (*List[*Event], error) { var b, err = c.get(ctx, path.Join(routes.FineTunes, id, "events")) if err != nil { diff --git a/list.go b/list.go index 5f5eda6..5184927 100644 --- a/list.go +++ b/list.go @@ -4,7 +4,10 @@ import ( "github.com/fabiustech/openai/objects" ) +// List represents a generic form of list of objects returned from many get endpoints. 
type List[T any] struct { + // Object specifies the object type (e.g. Model). Object objects.Object `json:"object"` - Data []T `json:"data"` + // Data contains the list of objects. + Data []T `json:"data"` } diff --git a/moderation.go b/moderation.go index 91568bb..3f134bf 100644 --- a/moderation.go +++ b/moderation.go @@ -9,7 +9,7 @@ import ( "github.com/fabiustech/openai/routes" ) -// ModerationRequest represents a request structure for moderation API. +// ModerationRequest contains all relevant fields for requests to the moderations endpoint. type ModerationRequest struct { // Input is the input text to classify. Input string `json:"input,omitempty"` From a72baf296782d3967c87666ffb4dd1d5201b131f Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 13:48:29 -0800 Subject: [PATCH 22/30] clean --- client.go | 39 ++++++++++++---------- client_test.go | 75 +++++++++++++++++++++++-------------------- completions.go | 4 +-- edits.go | 4 +-- embeddings.go | 18 ++++------- engines.go | 9 +++--- files.go | 45 +++++++++++++------------- fine_tunes.go | 20 ++++++++---- images.go | 18 +++++------ models/moderations.go | 2 +- moderation.go | 6 ++-- 11 files changed, 126 insertions(+), 114 deletions(-) diff --git a/client.go b/client.go index c9c8116..5a0a9d6 100644 --- a/client.go +++ b/client.go @@ -26,6 +26,7 @@ type Client struct { orgID *string // scheme and host are only used for testing. + // TODO: Figure out a better approach. 
scheme, host string } @@ -48,6 +49,22 @@ func NewClientWithOrg(token, org string) *Client { } } +func (c *Client) newRequest(ctx context.Context, method string, url string, body io.Reader) (*http.Request, error) { + var req, err = http.NewRequestWithContext(ctx, method, url, body) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json; charset=utf-8") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.token)) + + if c.orgID != nil { + req.Header.Set("OpenAI-Organization", *c.orgID) + } + + return req, nil +} + func (c *Client) post(ctx context.Context, path string, payload any) ([]byte, error) { var b, err = json.Marshal(payload) if err != nil { @@ -55,17 +72,12 @@ func (c *Client) post(ctx context.Context, path string, payload any) ([]byte, er } var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", c.reqURL(path), bytes.NewBuffer(b)) + req, err = c.newRequest(ctx, "POST", c.reqURL(path), bytes.NewBuffer(b)) if err != nil { return nil, err } - req.Header.Set("Content-Type", "application/json; charset=utf-8") - if c.orgID != nil { - req.Header.Set("OpenAI-Organization", *c.orgID) - } - var resp *http.Response resp, err = http.DefaultClient.Do(req) if err != nil { @@ -102,17 +114,13 @@ func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) } var req *http.Request - req, err = http.NewRequestWithContext(ctx, "POST", c.reqURL(routes.Files), &b) + req, err = c.newRequest(ctx, "POST", c.reqURL(routes.Files), &b) if err != nil { return nil, err } req.Header.Set("Content-Type", w.FormDataContentType()) - if c.orgID != nil { - req.Header.Set("OpenAI-Organization", *c.orgID) - } - var resp *http.Response resp, err = http.DefaultClient.Do(req) if err != nil { @@ -128,15 +136,11 @@ func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) } func (c *Client) get(ctx context.Context, path string) ([]byte, error) { - var req, err = http.NewRequestWithContext(ctx, "POST", 
c.reqURL(path), nil) + var req, err = c.newRequest(ctx, "POST", c.reqURL(path), nil) if err != nil { return nil, err } - if c.orgID != nil { - req.Header.Set("OpenAI-Organization", *c.orgID) - } - var resp *http.Response resp, err = http.DefaultClient.Do(req) if err != nil { @@ -152,7 +156,7 @@ func (c *Client) get(ctx context.Context, path string) ([]byte, error) { } func (c *Client) delete(ctx context.Context, path string) ([]byte, error) { - var req, err = http.NewRequestWithContext(ctx, "DELETE", c.reqURL(path), nil) + var req, err = c.newRequest(ctx, "DELETE", c.reqURL(path), nil) if err != nil { return nil, err } @@ -177,6 +181,7 @@ func (c *Client) reqURL(route string) string { Host: c.host, Path: path.Join(basePath, route), } + return u.String() } diff --git a/client_test.go b/client_test.go index e47d550..10bbff0 100644 --- a/client_test.go +++ b/client_test.go @@ -9,6 +9,7 @@ import ( "log" "net/http" "net/http/httptest" + "net/url" "os" "strconv" "strings" @@ -22,7 +23,8 @@ import ( ) /* -This test suite has been ported from the original repo: +This test suite has been ported from the original repo: https://github.com/sashabaranov/go-gpt3. +It is incomplete, and it's usefulness is questionable. TODO: Cover all endpoints. */ @@ -56,9 +58,9 @@ func TestAPI(t *testing.T) { } if len(fl.Data) > 0 { - _, err = c.GetFile(ctx, fl.Data[0].ID) + _, err = c.RetrieveFile(ctx, fl.Data[0].ID) if err != nil { - t.Fatalf("GetFile error: %v", err) + t.Fatalf("RetrieveFile error: %v", err) } } @@ -74,12 +76,17 @@ func TestAPI(t *testing.T) { } } -func newTestClient(host string) *Client { +func newTestClient(u string) (*Client, error) { + var h, err = url.Parse(u) + if err != nil { + return nil, err + } + return &Client{ token: testToken, - host: host, + host: h.Host, scheme: "http", - } + }, nil } // TestCompletions Tests the completions endpoint of the API using the mocked server. 
@@ -88,10 +95,9 @@ func TestCompletions(t *testing.T) { ts.Start() defer ts.Close() - var client = newTestClient(ts.URL) - ctx := context.Background() + var client, _ = newTestClient(ts.URL) - var _, err = client.CreateCompletion(ctx, &CompletionRequest{ + var _, err = client.CreateCompletion(context.Background(), &CompletionRequest{ Prompt: "Lorem ipsum", Model: models.TextDavinci003, MaxTokens: 5, @@ -107,11 +113,10 @@ func TestEdits(t *testing.T) { ts.Start() defer ts.Close() - var client = newTestClient(ts.URL) - ctx := context.Background() + var client, _ = newTestClient(ts.URL) var n = 3 - var resp, err = client.CreateEdit(ctx, &EditsRequest{ + var resp, err = client.CreateEdit(context.Background(), &EditsRequest{ Model: models.TextDavinciEdit001, Input: "Lorem ipsum dolor sit amet, consectetur adipiscing elit, " + "sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim" + @@ -271,24 +276,31 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { // handleImageEndpoint Handles the images endpoint by the test server. func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { - var err error - var resBytes []byte - - // imagess only accepts POST requests + // Images only accepts POST requests. if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var imageReq *CreateImageRequest - if imageReq, err = getImageBody(r); err != nil { + var ir, err = getImageBody(r) + if err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := &ImageResponse{ + + var resp = &ImageResponse{ Created: uint64(time.Now().Unix()), } - for i := 0; i < imageReq.N; i++ { + + // Handle default values. 
+ if ir.N == 0 { + ir.N = 1 + } + if ir.ResponseFormat == images.FormatInvalid { + ir.ResponseFormat = images.FormatURL + } + + for i := 0; i < ir.N; i++ { var imageData = &ImageData{} - switch imageReq.ResponseFormat { + switch ir.ResponseFormat { case images.FormatURL: imageData.URL = params.Optional("https://example.com/image.png") case images.FormatB64JSON: @@ -298,10 +310,11 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { http.Error(w, "invalid response format", http.StatusBadRequest) return } - res.Data = append(res.Data, imageData) + resp.Data = append(resp.Data, imageData) } - resBytes, _ = json.Marshal(res) - fmt.Fprintln(w, string(resBytes)) + + var b, _ = json.Marshal(resp) + w.Write(b) } // getCompletionBody Returns the body of the request to create a completion. @@ -331,6 +344,7 @@ func getImageBody(r *http.Request) (*CreateImageRequest, error) { if err != nil { return nil, err } + return image, nil } @@ -344,19 +358,12 @@ func numTokens(s string) int { } func TestImages(t *testing.T) { - // create the test server - var err error - ts := OpenAITestServer() + var ts = OpenAITestServer() ts.Start() defer ts.Close() - client := NewClient(testToken) - ctx := context.Background() - // client.BaseURL = ts.URL + "/v1" - - req := &CreateImageRequest{} - req.Prompt = "Lorem ipsum" - _, err = client.CreateImage(ctx, req) + var client, _ = newTestClient(ts.URL) + var _, err = client.CreateImage(context.Background(), &CreateImageRequest{Prompt: "Lorem ipsum"}) if err != nil { t.Fatalf("CreateImage error: %v", err) } diff --git a/completions.go b/completions.go index 670799d..da1f00e 100644 --- a/completions.go +++ b/completions.go @@ -117,14 +117,14 @@ type CompletionResponse struct { Usage *Usage `json:"usage"` } -// CreateCompletion ... +// CreateCompletion creates a completion for the provided prompt and parameters. 
func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest) (*CompletionResponse, error) { var b, err = c.post(ctx, routes.Completions, cr) if err != nil { return nil, err } - var resp *CompletionResponse + var resp = &CompletionResponse{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } diff --git a/edits.go b/edits.go index 7749248..0b9dfe0 100644 --- a/edits.go +++ b/edits.go @@ -47,11 +47,11 @@ type EditsResponse struct { Choices []*EditsChoice `json:"choices"` } -// CreateEdit ... +// CreateEdit creates a new edit for the provided input, instruction, and parameters. func (c *Client) CreateEdit(ctx context.Context, er *EditsRequest) (*EditsResponse, error) { var b, err = c.post(ctx, routes.Edits, er) - var resp *EditsResponse + var resp = &EditsResponse{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } diff --git a/embeddings.go b/embeddings.go index c7ee87a..d4320e4 100644 --- a/embeddings.go +++ b/embeddings.go @@ -30,29 +30,23 @@ type EmbeddingResponse struct { // EmbeddingRequest contains all relevant fields for requests to the embeddings endpoint. type EmbeddingRequest struct { - // Input is a slice of strings for which you want to generate an Embedding vector. - // Each input must not exceed 2048 tokens in length. - // OpenAPI suggests replacing newlines (\n) in your input with a single space, as they - // have observed inferior results when newlines are present. - // E.g. - // "The food was delicious and the waiter..." + // Input represents input text to get embeddings for, encoded as a strings. To get embeddings for multiple inputs in + //a single request, pass a slice of length > 1. Each input string must not exceed 8192 tokens in length. Input []string `json:"input"` - // ID of the model to use. You can use the List models API to see all of your available models, - // or see our Model overview for descriptions of them. + // Model is the ID of the model to use. 
Model models.Embedding `json:"model"` - // A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. + // User is a unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. User string `json:"user"` } -// CreateEmbeddings returns an EmbeddingResponse which will contain an Embedding for every item in |request.Input|. -// https://beta.openai.com/docs/api-reference/embeddings/create +// CreateEmbeddings creates an embedding vector representing the input text. func (c *Client) CreateEmbeddings(ctx context.Context, request *EmbeddingRequest) (*EmbeddingResponse, error) { var b, err = c.post(ctx, routes.Embeddings, request) if err != nil { return nil, err } - var resp *EmbeddingResponse + var resp = &EmbeddingResponse{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } diff --git a/engines.go b/engines.go index 2f29434..1f93cff 100644 --- a/engines.go +++ b/engines.go @@ -16,7 +16,7 @@ type Engine struct { Ready bool `json:"ready"` } -// ListEngines Lists the currently available engines, and provides basic +// ListEngines lists the currently available engines, and provides basic // information about each option such as the owner and availability. // // Deprecated: Please use their replacement, Models, instead. @@ -27,7 +27,7 @@ func (c *Client) ListEngines(ctx context.Context) (*List[*Engine], error) { return nil, err } - var el *List[*Engine] + var el = &List[*Engine]{} if err = json.Unmarshal(b, el); err != nil { return nil, err } @@ -35,8 +35,7 @@ func (c *Client) ListEngines(ctx context.Context) (*List[*Engine], error) { return el, nil } -// GetEngine Retrieves an engine instance, providing basic information about -// the engine such as the owner and availability. +// GetEngine retrieves a model instance, providing basic information about it such as the owner and availability. // // Deprecated: Please use their replacement, Models, instead. 
// https://beta.openai.com/docs/api-reference/models @@ -46,7 +45,7 @@ func (c *Client) GetEngine(ctx context.Context, id string) (*Engine, error) { return nil, err } - var e *Engine + var e = &Engine{} if err = json.Unmarshal(b, e); err != nil { return nil, err } diff --git a/files.go b/files.go index 9609d07..6dbf411 100644 --- a/files.go +++ b/files.go @@ -10,7 +10,7 @@ import ( "github.com/fabiustech/openai/routes" ) -// FileRequest ... +// FileRequest contains all relevant data for upload requests to the files endpoint. type FileRequest struct { // File is the JSON Lines file to be uploaded. If the purpose is set to "fine-tune", each line is a JSON record // with "prompt" and "completion" fields representing your training examples: @@ -44,14 +44,30 @@ type File struct { Purpose string `json:"purpose"` } -// UploadFile ... +// ListFiles returns a list of files that belong to the user's organization. +func (c *Client) ListFiles(ctx context.Context) (*List[*File], error) { + var b, err = c.get(ctx, routes.Files) + if err != nil { + return nil, err + } + + var fl = &List[*File]{} + if err = json.Unmarshal(b, fl); err != nil { + return nil, err + } + + return fl, nil +} + +// UploadFile uploads a file that contains document(s) to be used across various endpoints/features. Currently, the size +// of all the files uploaded by one organization can be up to 1 GB. func (c *Client) UploadFile(ctx context.Context, fr *FileRequest) (*File, error) { var b, err = c.postFile(ctx, fr) if err != nil { return nil, err } - var f *File + var f = &File{} if err = json.Unmarshal(b, f); err != nil { return nil, err } @@ -59,36 +75,21 @@ func (c *Client) UploadFile(ctx context.Context, fr *FileRequest) (*File, error) return f, nil } -// DeleteFile ... +// DeleteFile deletes a file. func (c *Client) DeleteFile(ctx context.Context, id string) error { var _, err = c.delete(ctx, path.Join(routes.Files, id)) return err } -// ListFiles ... 
-func (c *Client) ListFiles(ctx context.Context) (*List[*File], error) { - var b, err = c.get(ctx, routes.Files) - if err != nil { - return nil, err - } - - var fl *List[*File] - if err = json.Unmarshal(b, fl); err != nil { - return nil, err - } - - return fl, nil -} - -// GetFile ... -func (c *Client) GetFile(ctx context.Context, id string) (*File, error) { +// RetrieveFile returns information about a specific file. +func (c *Client) RetrieveFile(ctx context.Context, id string) (*File, error) { var b, err = c.get(ctx, path.Join(routes.Files, id)) if err != nil { return nil, err } - var f *File + var f = &File{} if err = json.Unmarshal(b, f); err != nil { return nil, err } diff --git a/fine_tunes.go b/fine_tunes.go index bf26f7b..56a1e3a 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -135,14 +135,15 @@ type FineTuneDeletionResponse struct { Deleted bool `json:"deleted"` } -// CreateFineTune ... +// CreateFineTune creates a job that fine-tunes a specified model from a given dataset. *FineTuneResponse includes +// details of the enqueued job including job status and the name of the fine-tuned models once complete. func (c *Client) CreateFineTune(ctx context.Context, ftr *FineTuneRequest) (*FineTuneResponse, error) { var b, err = c.post(ctx, routes.FineTunes, ftr) if err != nil { return nil, err } - var f *FineTuneResponse + var f = &FineTuneResponse{} if err = json.Unmarshal(b, f); err != nil { return nil, err } @@ -150,13 +151,14 @@ func (c *Client) CreateFineTune(ctx context.Context, ftr *FineTuneRequest) (*Fin return f, nil } +// ListFineTunes lists your organization's fine-tuning jobs. 
func (c *Client) ListFineTunes(ctx context.Context) (*List[*FineTuneResponse], error) { var b, err = c.get(ctx, routes.FineTunes) if err != nil { return nil, err } - var l *List[*FineTuneResponse] + var l = &List[*FineTuneResponse]{} if err = json.Unmarshal(b, l); err != nil { return nil, err } @@ -164,13 +166,14 @@ func (c *Client) ListFineTunes(ctx context.Context) (*List[*FineTuneResponse], e return l, nil } +// RetrieveFineTune gets info about the fine-tune job. func (c *Client) RetrieveFineTune(ctx context.Context, id string) (*FineTuneResponse, error) { var b, err = c.get(ctx, path.Join(routes.FineTunes, id)) if err != nil { return nil, err } - var f *FineTuneResponse + var f = &FineTuneResponse{} if err = json.Unmarshal(b, f); err != nil { return nil, err } @@ -178,13 +181,14 @@ func (c *Client) RetrieveFineTune(ctx context.Context, id string) (*FineTuneResp return f, nil } +// CancelFineTune immediately cancels a fine-tune job. func (c *Client) CancelFineTune(ctx context.Context, id string) (*FineTuneResponse, error) { var b, err = c.post(ctx, path.Join(routes.FineTunes, id, "cancel"), nil) if err != nil { return nil, err } - var f *FineTuneResponse + var f = &FineTuneResponse{} if err = json.Unmarshal(b, f); err != nil { return nil, err } @@ -192,6 +196,7 @@ func (c *Client) CancelFineTune(ctx context.Context, id string) (*FineTuneRespon return f, nil } +// ListFineTuneEvents returns fine-grained status updates for a fine-tune job. // TODO: Support streaming (in a different method). 
func (c *Client) ListFineTuneEvents(ctx context.Context, id string) (*List[*Event], error) { var b, err = c.get(ctx, path.Join(routes.FineTunes, id, "events")) @@ -199,7 +204,7 @@ func (c *Client) ListFineTuneEvents(ctx context.Context, id string) (*List[*Even return nil, err } - var l *List[*Event] + var l = &List[*Event]{} if err = json.Unmarshal(b, l); err != nil { return nil, err } @@ -207,13 +212,14 @@ func (c *Client) ListFineTuneEvents(ctx context.Context, id string) (*List[*Even return l, nil } +// DeleteFineTune delete a fine-tuned model. You must have the Owner role in your organization. func (c *Client) DeleteFineTune(ctx context.Context, id string) (*FineTuneDeletionResponse, error) { var b, err = c.delete(ctx, path.Join(routes.FineTunes, id)) if err != nil { return nil, err } - var f *FineTuneDeletionResponse + var f = &FineTuneDeletionResponse{} if err = json.Unmarshal(b, f); err != nil { return nil, err } diff --git a/images.go b/images.go index 603b816..0f46817 100644 --- a/images.go +++ b/images.go @@ -12,7 +12,7 @@ import ( type CreateImageRequest struct { // Prompt is a text description of the desired image(s). The maximum length is 1000 characters. Prompt string `json:"prompt"` - *ImageRequestFields + ImageRequestFields } // EditImageRequest contains all relevant fields for requests to the images/edits endpoint. @@ -25,14 +25,14 @@ type EditImageRequest struct { Mask string `json:"mask,omitempty"` // Prompt is a text description of the desired image(s). The maximum length is 1000 characters. Prompt string `json:"prompt"` - *ImageRequestFields + ImageRequestFields } // VariationImageRequest contains all relevant fields for requests to the images/variations endpoint. type VariationImageRequest struct { // Image is the image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. 
Image string `json:"image"` - *ImageRequestFields + ImageRequestFields } // ImageRequestFields contains the common fields for all images endpoints. @@ -66,14 +66,14 @@ type ImageData struct { B64JSON *string `json:"b64_json,omitempty"` } -// CreateImage ... +// CreateImage creates an image (or images) given a prompt. func (c *Client) CreateImage(ctx context.Context, ir *CreateImageRequest) (*ImageResponse, error) { var b, err = c.post(ctx, routes.ImageGenerations, ir) if err != nil { return nil, err } - var resp *ImageResponse + var resp = &ImageResponse{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } @@ -81,14 +81,14 @@ func (c *Client) CreateImage(ctx context.Context, ir *CreateImageRequest) (*Imag return resp, nil } -// EditImage ... +// EditImage creates an edited or extended image (or images) given an original image and a prompt. func (c *Client) EditImage(ctx context.Context, eir *EditImageRequest) (*ImageResponse, error) { var b, err = c.post(ctx, routes.ImageEdits, eir) if err != nil { return nil, err } - var resp *ImageResponse + var resp = &ImageResponse{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } @@ -96,14 +96,14 @@ func (c *Client) EditImage(ctx context.Context, eir *EditImageRequest) (*ImageRe return resp, nil } -// ImageVariation ... +// ImageVariation creates a variation (or variations) of a given image. func (c *Client) ImageVariation(ctx context.Context, vir *VariationImageRequest) (*ImageResponse, error) { var b, err = c.post(ctx, routes.ImageVariations, vir) if err != nil { return nil, err } - var resp *ImageResponse + var resp = &ImageResponse{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } diff --git a/models/moderations.go b/models/moderations.go index fa65c7d..0d6ff33 100644 --- a/models/moderations.go +++ b/models/moderations.go @@ -1,6 +1,6 @@ package models -// Moderation represents all models available for use with the Moderations endpoint. 
+// Moderation represents all models available for use with the CreateModeration endpoint. type Moderation int const ( diff --git a/moderation.go b/moderation.go index 3f134bf..89285ea 100644 --- a/moderation.go +++ b/moderation.go @@ -54,14 +54,14 @@ type ModerationResponse struct { Results []Result `json:"results"` } -// Moderations ... -func (c *Client) Moderations(ctx context.Context, mr *ModerationRequest) (*ModerationResponse, error) { +// CreateModeration classifies if text violates OpenAI's Content Policy. +func (c *Client) CreateModeration(ctx context.Context, mr *ModerationRequest) (*ModerationResponse, error) { var b, err = c.post(ctx, routes.Moderations, mr) if err != nil { return nil, err } - var resp *ModerationResponse + var resp = &ModerationResponse{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } From 0a0a23ed0e9e6081e62b8fe42d051a2afc3b78cd Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 13:51:16 -0800 Subject: [PATCH 23/30] cool --- client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client.go b/client.go index 5a0a9d6..d844532 100644 --- a/client.go +++ b/client.go @@ -136,7 +136,7 @@ func (c *Client) postFile(ctx context.Context, fr *FileRequest) ([]byte, error) } func (c *Client) get(ctx context.Context, path string) ([]byte, error) { - var req, err = c.newRequest(ctx, "POST", c.reqURL(path), nil) + var req, err = c.newRequest(ctx, "GET", c.reqURL(path), nil) if err != nil { return nil, err } From 1ad0480f483f7f4a2736458771490a0a4857c028 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 14:07:52 -0800 Subject: [PATCH 24/30] sanity check failures --- client.go | 2 +- client_test.go | 18 +----------------- completions.go | 16 ++++++++-------- edits.go | 9 ++++++--- embeddings.go | 2 +- 5 files changed, 17 insertions(+), 30 deletions(-) diff --git a/client.go b/client.go index d844532..ad30be6 100644 --- a/client.go +++ b/client.go @@ -191,7 +191,7 @@ func 
interpretResponse(resp *http.Response) error { if err != nil { return fmt.Errorf("error, status code: %d", resp.StatusCode) } - var er *errorResponse + var er = &errorResponse{} if err = json.Unmarshal(b, er); err != nil || er.Error == nil { return fmt.Errorf("error, status code: %d, msg: %s", resp.StatusCode, string(b)) } diff --git a/client_test.go b/client_test.go index 10bbff0..837790a 100644 --- a/client_test.go +++ b/client_test.go @@ -135,22 +135,6 @@ func TestEdits(t *testing.T) { func TestEmbedding(t *testing.T) { embeddedModels := []models.Embedding{ - models.AdaSimilarity, - models.BabbageSimilarity, - models.CurieSimilarity, - models.DavinciSimilarity, - models.AdaSearchDocument, - models.AdaSearchQuery, - models.BabbageSearchDocument, - models.BabbageSearchQuery, - models.CurieSearchDocument, - models.CurieSearchQuery, - models.DavinciSearchDocument, - models.DavinciSearchQuery, - models.AdaCodeSearchCode, - models.AdaCodeSearchText, - models.BabbageCodeSearchCode, - models.BabbageCodeSearchText, models.AdaEmbeddingV2, } for _, model := range embeddedModels { @@ -314,7 +298,7 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { } var b, _ = json.Marshal(resp) - w.Write(b) + _, _ = w.Write(b) } // getCompletionBody Returns the body of the request to create a completion. diff --git a/completions.go b/completions.go index da1f00e..e946418 100644 --- a/completions.go +++ b/completions.go @@ -14,22 +14,22 @@ type CompletionRequest struct { // Model specifies the ID of the model to use. // See more here: https://beta.openai.com/docs/models/overview Model models.Completion `json:"model"` - // Prompt specifies the prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, - // or array of token arrays. 
Note that <|endoftext|> is the document separator that the model sees during + // Prompt specifies the prompt(s) to generate completions for, encoded as a string, array of strings, array of + // tokens, or array of token arrays. Note that <|endoftext|> is the document separator that the model sees during // training, so if a prompt is not specified the model will generate as if from the beginning of a new document. // Defaults to <|endoftext|>. Prompt string `json:"prompt,omitempty"` // Suffix specifies the suffix that comes after a completion of inserted text. // Defaults to null. Suffix string `json:"suffix,omitempty"` - // MaxTokens specifies the maximum number of tokens to generate in the completion. The token count of your prompt plus - // max_tokens cannot exceed the model's context length. Most models have a context length of 2048 tokens (except - // for the newest models, which support 4096). + // MaxTokens specifies the maximum number of tokens to generate in the completion. The token count of your prompt + // plus max_tokens cannot exceed the model's context length. Most models have a context length of 2048 tokens + // (except for the newest models, which support 4096). // Defaults to 16. MaxTokens int `json:"max_tokens,omitempty"` - // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative - // applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally recommends altering - // this or top_p but not both. + // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try + // 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally + //recommends altering this or top_p but not both. // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 // Defaults to 1. 
Temperature *float32 `json:"temperature,omitempty"` diff --git a/edits.go b/edits.go index 0b9dfe0..d573867 100644 --- a/edits.go +++ b/edits.go @@ -20,9 +20,9 @@ type EditsRequest struct { // N specifies how many edits to generate for the input and instruction. // Defaults to 1. N int `json:"n,omitempty"` - // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try 0.9 for more creative - // applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally recommends altering - // this or top_p but not both. + // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. + // Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI + // generally recommends altering this or top_p but not both. // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 // Defaults to 1. Temperature *float32 `json:"temperature,omitempty"` @@ -50,6 +50,9 @@ type EditsResponse struct { // CreateEdit creates a new edit for the provided input, instruction, and parameters. func (c *Client) CreateEdit(ctx context.Context, er *EditsRequest) (*EditsResponse, error) { var b, err = c.post(ctx, routes.Edits, er) + if err != nil { + return nil, err + } var resp = &EditsResponse{} if err = json.Unmarshal(b, resp); err != nil { diff --git a/embeddings.go b/embeddings.go index d4320e4..e7d095f 100644 --- a/embeddings.go +++ b/embeddings.go @@ -31,7 +31,7 @@ type EmbeddingResponse struct { // EmbeddingRequest contains all relevant fields for requests to the embeddings endpoint. type EmbeddingRequest struct { // Input represents input text to get embeddings for, encoded as a strings. To get embeddings for multiple inputs in - //a single request, pass a slice of length > 1. Each input string must not exceed 8192 tokens in length. 
+ // a single request, pass a slice of length > 1. Each input string must not exceed 8192 tokens in length. Input []string `json:"input"` // Model is the ID of the model to use. Model models.Embedding `json:"model"` From 4a817161d1142a3b232c2d5d53a42bb348136823 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 14:20:52 -0800 Subject: [PATCH 25/30] try --- .github/workflows/pr.yml | 7 ++++--- .golangci.yml | 2 +- client_test.go | 6 ++---- completions.go | 2 +- 4 files changed, 8 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index ee77f3c..6780196 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -11,7 +11,7 @@ jobs: - name: Setup Go uses: actions/setup-go@v2 with: - go-version: '1.18' + go-version: '1.19' - name: Run vet run: | go vet . @@ -19,6 +19,7 @@ jobs: uses: golangci/golangci-lint-action@v3 with: version: latest + # # Run testing on the code - # - name: Run testing - # run: cd test && go test -v + - name: Run testing + run: cd test && go test -v diff --git a/.golangci.yml b/.golangci.yml index bdf66a8..ea4ba05 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -205,7 +205,7 @@ linters: - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. 
- stylecheck # Stylecheck is a replacement for golint - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - - testpackage # linter that makes you use a separate _test package + # - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters diff --git a/client_test.go b/client_test.go index 837790a..8672186 100644 --- a/client_test.go +++ b/client_test.go @@ -278,14 +278,12 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { if ir.N == 0 { ir.N = 1 } - if ir.ResponseFormat == images.FormatInvalid { - ir.ResponseFormat = images.FormatURL - } for i := 0; i < ir.N; i++ { var imageData = &ImageData{} switch ir.ResponseFormat { - case images.FormatURL: + // Invalid is the go default value, and URL is the default API behavior. + case images.FormatURL, images.FormatInvalid: imageData.URL = params.Optional("https://example.com/image.png") case images.FormatB64JSON: // This decodes to "{}" in base64. diff --git a/completions.go b/completions.go index e946418..0b09a11 100644 --- a/completions.go +++ b/completions.go @@ -29,7 +29,7 @@ type CompletionRequest struct { MaxTokens int `json:"max_tokens,omitempty"` // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try // 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally - //recommends altering this or top_p but not both. + // recommends altering this or top_p but not both. // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 // Defaults to 1. 
 Temperature *float32 `json:"temperature,omitempty"` 

From 9a8497757f70f174cbe7dc118e4ce4445eec49fd Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 14:22:16 -0800 Subject: [PATCH 26/30] fix workflow --- .github/workflows/pr.yml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 6780196..b9b3fbd 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -19,7 +19,5 @@ jobs: uses: golangci/golangci-lint-action@v3 with: version: latest - - # # Run testing on the code - name: Run testing - run: cd test && go test -v + run: go test -v From 46ffeb359624c3ed1fed9d5dc9ae61ab1ffca9b0 Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 20:54:51 -0800 Subject: [PATCH 27/30] cleanup --- completions.go | 114 ++++++++++++++++++++++++++++++++++++++++++- fine_tunes.go | 12 ++--- images.go | 35 ++++++++++--- models/fine_tunes.go | 9 ++++ 4 files changed, 156 insertions(+), 14 deletions(-) diff --git a/completions.go b/completions.go index 0b09a11..0a0c60d 100644 --- a/completions.go +++ b/completions.go @@ -91,6 +91,91 @@ type CompletionRequest struct { User string `json:"user,omitempty"` } +// FineTunedCompletionRequest contains all relevant fields for requests to the completions endpoint, +// using a fine-tuned model. +// +// Note: This seems completely redundant, and may change (or be removed) if a better / simpler solution arises. +type FineTunedCompletionRequest struct { + // Model specifies the ID of the model to use. + // See more here: https://beta.openai.com/docs/models/overview + Model models.FineTunedModel `json:"model"` + // Prompt specifies the prompt(s) to generate completions for, encoded as a string, array of strings, array of + // tokens, or array of token arrays. Note that <|endoftext|> is the document separator that the model sees during + // training, so if a prompt is not specified the model will generate as if from the beginning of a new document. 
 + // Defaults to <|endoftext|>. + Prompt string `json:"prompt,omitempty"` + // Suffix specifies the suffix that comes after a completion of inserted text. + // Defaults to null. + Suffix string `json:"suffix,omitempty"` + // MaxTokens specifies the maximum number of tokens to generate in the completion. The token count of your prompt + // plus max_tokens cannot exceed the model's context length. Most models have a context length of 2048 tokens + // (except for the newest models, which support 4096). + // Defaults to 16. + MaxTokens int `json:"max_tokens,omitempty"` + // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try + // 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. OpenAI generally + // recommends altering this or top_p but not both. + // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 + // Defaults to 1. + Temperature *float32 `json:"temperature,omitempty"` + // TopP specifies an alternative to sampling with temperature, called nucleus sampling, where the model considers + // the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + // probability mass are considered. OpenAI generally recommends altering this or temperature but not both. + // Defaults to 1. + TopP *float32 `json:"top_p,omitempty"` + // N specifies how many completions to generate for each prompt. + // Note: Because this parameter generates many completions, it can quickly consume your token quota. Use carefully + // and ensure that you have reasonable settings for max_tokens and stop. + // Defaults to 1. + N int `json:"n,omitempty"` + // Stream specifies whether to stream back partial progress. If set, tokens will be sent as data-only server-sent + // events as they become available, with the stream terminated by a data: [DONE] message. + // Defaults to false. 
+ Stream bool `json:"stream,omitempty"` + // LogProbs specifies to include the log probabilities on the logprobs most likely tokens, as well the chosen + // tokens. For example, if logprobs is 5, the API will return a list of the 5 most likely tokens. The API will + // always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response. + // The maximum value for logprobs is 5. + // Defaults to null. + LogProbs *int `json:"logprobs,omitempty"` + // Echo specifies to echo back the prompt in addition to the completion. + // Defaults to false. + Echo bool `json:"echo,omitempty"` + // Stop specifies up to 4 sequences where the API will stop generating further tokens. The returned text will not + // contain the stop sequence. + Stop []string `json:"stop,omitempty"` + // PresencePenalty can be a number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + // appear in the text so far, increasing the model's likelihood to talk about new topics. + // Defaults to 0. + PresencePenalty float32 `json:"presence_penalty,omitempty"` + // FrequencyPenalty can be a number between -2.0 and 2.0. Positive values penalize new tokens based on their + // existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + // Defaults to 0. + FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` + // Generates best_of completions server-side and returns the "best" (the one with the highest log probability per + // token). Results cannot be streamed. When used with n, best_of controls the number of candidate completions and n + // specifies how many to return – best_of must be greater than n. Note: Because this parameter generates many + // completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings + // for max_tokens and stop. + // Defaults to 1. 
 + BestOf *int `json:"best_of,omitempty"` + // LogitBias modifies the likelihood of specified tokens appearing in the completion. Accepts a json object that + // maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. + // Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will + // vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like + // -100 or 100 should result in a ban or exclusive selection of the relevant token. + // As an example, you can pass {"50256": -100} to prevent the <|endoftext|> token from being generated. + // + // You can use this tokenizer tool to convert text to token IDs: + // https://beta.openai.com/tokenizer + // + // Defaults to null. + LogitBias map[string]int `json:"logit_bias,omitempty"` + // User is a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + // See more here: https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids + User string `json:"user,omitempty"` +} + // CompletionChoice represents one of possible completions. type CompletionChoice struct { Text string `json:"text"` @@ -107,7 +192,7 @@ type LogprobResult struct { TextOffset []int `json:"text_offset"` } -// CompletionResponse represents a response structure for completion API. +// CompletionResponse is the response from the completions endpoint. type CompletionResponse struct { ID string `json:"id"` Object objects.Object `json:"object"` @@ -117,6 +202,18 @@ type CompletionResponse struct { Usage *Usage `json:"usage"` } +// FineTunedCompletionResponse is the response from the completions endpoint. +// +// Note: This seems completely redundant, and may change (or be removed) if a better / simpler solution arises. 
+type FineTunedCompletionResponse struct { + ID string `json:"id"` + Object objects.Object `json:"object"` + Created uint64 `json:"created"` + Model models.FineTunedModel `json:"model"` + Choices []*CompletionChoice `json:"choices"` + Usage *Usage `json:"usage"` +} + // CreateCompletion creates a completion for the provided prompt and parameters. func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest) (*CompletionResponse, error) { var b, err = c.post(ctx, routes.Completions, cr) @@ -131,3 +228,18 @@ func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest) (* return resp, nil } + +// CreateFineTunedCompletion creates a completion for the provided prompt and parameters, using a fine-tuned model. +func (c *Client) CreateFineTunedCompletion(ctx context.Context, cr *FineTunedCompletionRequest) (*FineTunedCompletionResponse, error) { + var b, err = c.post(ctx, routes.Completions, cr) + if err != nil { + return nil, err + } + + var resp = &FineTunedCompletionResponse{} + if err = json.Unmarshal(b, resp); err != nil { + return nil, err + } + + return resp, nil +} diff --git a/fine_tunes.go b/fine_tunes.go index 56a1e3a..3b9e178 100644 --- a/fine_tunes.go +++ b/fine_tunes.go @@ -101,12 +101,12 @@ type Event struct { // FineTuneResponse is the response from fine-tunes endpoints. 
type FineTuneResponse struct { - ID string `json:"id"` - Object objects.Object `json:"object"` - Model models.FineTune `json:"model"` - CreatedAt uint64 `json:"created_at"` - Events []*Event `json:"events,omitempty"` - FineTunedModel *string `json:"fine_tuned_model"` + ID string `json:"id"` + Object objects.Object `json:"object"` + Model models.FineTune `json:"model"` + CreatedAt uint64 `json:"created_at"` + Events []*Event `json:"events,omitempty"` + FineTunedModel *models.FineTunedModel `json:"fine_tuned_model"` Hyperparams struct { BatchSize int `json:"batch_size"` LearningRateMultiplier float64 `json:"learning_rate_multiplier"` diff --git a/images.go b/images.go index 0f46817..aaa9232 100644 --- a/images.go +++ b/images.go @@ -12,7 +12,20 @@ import ( type CreateImageRequest struct { // Prompt is a text description of the desired image(s). The maximum length is 1000 characters. Prompt string `json:"prompt"` - ImageRequestFields + // N specifies the number of images to generate. Must be between 1 and 10. + // Defaults to 1. + N int `json:"n,omitempty"` + // Size specifies the size of the generated images. Must be one of images.Size256x256, images.Size512x512, or + // images.Size1024x1024. + // Defaults to images.Size1024x1024. + Size images.Size `json:"size,omitempty"` + // ResponseFormat specifies the format in which the generated images are returned. Must be one of images.FormatURL + // or images.FormatB64JSON. + // Defaults to images.FormatURL. + ResponseFormat images.Format `json:"response_format,omitempty"` + // User specifies a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse: + // https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids. + User string `json:"user,omitempty"` } // EditImageRequest contains all relevant fields for requests to the images/edits endpoint. 
@@ -25,18 +38,26 @@ type EditImageRequest struct { Mask string `json:"mask,omitempty"` // Prompt is a text description of the desired image(s). The maximum length is 1000 characters. Prompt string `json:"prompt"` - ImageRequestFields + // N specifies the number of images to generate. Must be between 1 and 10. + // Defaults to 1. + N int `json:"n,omitempty"` + // Size specifies the size of the generated images. Must be one of images.Size256x256, images.Size512x512, or + // images.Size1024x1024. + // Defaults to images.Size1024x1024. + Size images.Size `json:"size,omitempty"` + // ResponseFormat specifies the format in which the generated images are returned. Must be one of images.FormatURL + // or images.FormatB64JSON. + // Defaults to images.FormatURL. + ResponseFormat images.Format `json:"response_format,omitempty"` + // User specifies a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse: + // https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids. + User string `json:"user,omitempty"` } // VariationImageRequest contains all relevant fields for requests to the images/variations endpoint. type VariationImageRequest struct { // Image is the image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. Image string `json:"image"` - ImageRequestFields -} - -// ImageRequestFields contains the common fields for all images endpoints. -type ImageRequestFields struct { // N specifies the number of images to generate. Must be between 1 and 10. // Defaults to 1. N int `json:"n,omitempty"` diff --git a/models/fine_tunes.go b/models/fine_tunes.go index 4e4caac..409af8a 100644 --- a/models/fine_tunes.go +++ b/models/fine_tunes.go @@ -57,3 +57,12 @@ var stringToFineTune = map[string]FineTune{ "ada": Ada, "babbage": Babbage, } + +// FineTunedModel represents the name of a fine-tuned model which was +// previously generated. 
+type FineTunedModel string + +// NewFineTunedModel converts a string to FineTunedModel. +func NewFineTunedModel(name string) FineTunedModel { + return FineTunedModel(name) +} From 14a6cb74e440a74a975f9ec5617cec248d61377b Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 21:02:00 -0800 Subject: [PATCH 28/30] . --- client_test.go | 10 ++--- completions.go | 113 ++++--------------------------------------------- 2 files changed, 13 insertions(+), 110 deletions(-) diff --git a/client_test.go b/client_test.go index 8672186..0f163bc 100644 --- a/client_test.go +++ b/client_test.go @@ -97,7 +97,7 @@ func TestCompletions(t *testing.T) { var client, _ = newTestClient(ts.URL) - var _, err = client.CreateCompletion(context.Background(), &CompletionRequest{ + var _, err = client.CreateCompletion(context.Background(), &CompletionRequest[models.Completion]{ Prompt: "Lorem ipsum", Model: models.TextDavinci003, MaxTokens: 5, @@ -221,12 +221,12 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) } - var completionReq *CompletionRequest + var completionReq *CompletionRequest[models.Completion] if completionReq, err = getCompletionBody(r); err != nil { http.Error(w, "could not read request", http.StatusInternalServerError) return } - res := &CompletionResponse{ + res := &CompletionResponse[models.Completion]{ ID: strconv.Itoa(int(time.Now().Unix())), Object: objects.TextCompletion, Created: uint64(time.Now().Unix()), @@ -300,8 +300,8 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) { } // getCompletionBody Returns the body of the request to create a completion. 
-func getCompletionBody(r *http.Request) (*CompletionRequest, error) { - var completion = &CompletionRequest{} +func getCompletionBody(r *http.Request) (*CompletionRequest[models.Completion], error) { + var completion = &CompletionRequest[models.Completion]{} // read the request body reqBody, err := io.ReadAll(r.Body) if err != nil { diff --git a/completions.go b/completions.go index 0a0c60d..3b35d4e 100644 --- a/completions.go +++ b/completions.go @@ -10,95 +10,10 @@ import ( ) // CompletionRequest contains all relevant fields for requests to the completions endpoint. -type CompletionRequest struct { +type CompletionRequest[T models.Completion | models.FineTunedModel] struct { // Model specifies the ID of the model to use. // See more here: https://beta.openai.com/docs/models/overview - Model models.Completion `json:"model"` - // Prompt specifies the prompt(s) to generate completions for, encoded as a string, array of strings, array of - // tokens, or array of token arrays. Note that <|endoftext|> is the document separator that the model sees during - // training, so if a prompt is not specified the model will generate as if from the beginning of a new document. - // Defaults to <|endoftext|>. - Prompt string `json:"prompt,omitempty"` - // Suffix specifies the suffix that comes after a completion of inserted text. - // Defaults to null. - Suffix string `json:"suffix,omitempty"` - // MaxTokens specifies the maximum number of tokens to generate in the completion. The token count of your prompt - // plus max_tokens cannot exceed the model's context length. Most models have a context length of 2048 tokens - // (except for the newest models, which support 4096). - // Defaults to 16. - MaxTokens int `json:"max_tokens,omitempty"` - // Temperature specifies what sampling temperature to use. Higher values means the model will take more risks. Try - // 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer. 
OpenAI generally - // recommends altering this or top_p but not both. - // More on sampling temperature: https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277 - // Defaults to 1. - Temperature *float32 `json:"temperature,omitempty"` - // TopP specifies an alternative to sampling with temperature, called nucleus sampling, where the model considers - // the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - // probability mass are considered. OpenAI generally recommends altering this or temperature but not both. - // Defaults to 1. - TopP *float32 `json:"top_p,omitempty"` - // N specifies how many completions to generate for each prompt. - // Note: Because this parameter generates many completions, it can quickly consume your token quota. Use carefully - // and ensure that you have reasonable settings for max_tokens and stop. - // Defaults to 1. - N int `json:"n,omitempty"` - // Steam specifies Whether to stream back partial progress. If set, tokens will be sent as data-only server-sent - // events as they become available, with the stream terminated by a data: [DONE] message. - // Defaults to false. - Stream bool `json:"stream,omitempty"` - // LogProbs specifies to include the log probabilities on the logprobs most likely tokens, as well the chosen - // tokens. For example, if logprobs is 5, the API will return a list of the 5 most likely tokens. The API will - // always return the logprob of the sampled token, so there may be up to logprobs+1 elements in the response. - // The maximum value for logprobs is 5. - // Defaults to null. - LogProbs *int `json:"logprobs,omitempty"` - // Echo specifies to echo back the prompt in addition to the completion. - // Defaults to false. - Echo bool `json:"echo,omitempty"` - // Stop specifies up to 4 sequences where the API will stop generating further tokens. The returned text will not - // contain the stop sequence. 
- Stop []string `json:"stop,omitempty"` - // PresencePenalty can be a number between -2.0 and 2.0. Positive values penalize new tokens based on whether they - // appear in the text so far, increasing the model's likelihood to talk about new topics. - // Defaults to 0. - PresencePenalty float32 `json:"presence_penalty,omitempty"` - // FrequencyPenalty can be a number between -2.0 and 2.0. Positive values penalize new tokens based on their - // existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - // Defaults to 0. - FrequencyPenalty float32 `json:"frequency_penalty,omitempty"` - // Generates best_of completions server-side and returns the "best" (the one with the highest log probability per - // token). Results cannot be streamed. When used with n, best_of controls the number of candidate completions and n - // specifies how many to return – best_of must be greater than n. Note: Because this parameter generates many - // completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings - // for max_tokens and stop. - // Defaults to 1. - BestOf *int `json:"best_of,omitempty"` - // LogitBias modifies the likelihood of specified tokens appearing in the completion. Accepts a json object that - // maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. - // Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will - // vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like - // -100 or 100 should result in a ban or exclusive selection of the relevant token. - // As an example, you can pass {"50256": -100} to prevent the <|endoftext|> token from being generated. - // - // You can use this tokenizer tool to convert text to token IDs: - // https://beta.openai.com/tokenizer - // - // Defaults to null. 
- LogitBias map[string]int `json:"logit_bias,omitempty"` - // User is a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. - // See more here: https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids - User string `json:"user,omitempty"` -} - -// FineTunedCompletionRequest contains all relevant fields for requests to the completions endpoint, -// using a fine-tuned model. -// -// Note: This seems completely redundant, and may change (or be removed) if a better / simpler solution arrises. -type FineTunedCompletionRequest struct { - // Model specifies the ID of the model to use. - // See more here: https://beta.openai.com/docs/models/overview - Model models.FineTunedModel `json:"model"` + Model T `json:"model"` // Prompt specifies the prompt(s) to generate completions for, encoded as a string, array of strings, array of // tokens, or array of token arrays. Note that <|endoftext|> is the document separator that the model sees during // training, so if a prompt is not specified the model will generate as if from the beginning of a new document. @@ -193,35 +108,23 @@ type LogprobResult struct { } // CompletionResponse is the response from the completions endpoint. -type CompletionResponse struct { +type CompletionResponse[T models.Completion | models.FineTunedModel] struct { ID string `json:"id"` Object objects.Object `json:"object"` Created uint64 `json:"created"` - Model models.Completion `json:"model"` + Model T `json:"model"` Choices []*CompletionChoice `json:"choices"` Usage *Usage `json:"usage"` } -// FineTunedCompletionResponse is the response from the completions endpoint. -// -// Note: This seems completely redundant, and may change (or be removed) if a better / simpler solution arrises. 
-type FineTunedCompletionResponse struct { - ID string `json:"id"` - Object objects.Object `json:"object"` - Created uint64 `json:"created"` - Model models.FineTunedModel `json:"model"` - Choices []*CompletionChoice `json:"choices"` - Usage *Usage `json:"usage"` -} - // CreateCompletion creates a completion for the provided prompt and parameters. -func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest) (*CompletionResponse, error) { +func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest[models.Completion]) (*CompletionResponse[models.Completion], error) { var b, err = c.post(ctx, routes.Completions, cr) if err != nil { return nil, err } - var resp = &CompletionResponse{} + var resp = &CompletionResponse[models.Completion]{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } @@ -230,13 +133,13 @@ func (c *Client) CreateCompletion(ctx context.Context, cr *CompletionRequest) (* } // CreateFineTunedCompletion creates a completion for the provided prompt and parameters, using a fine-tuned model. -func (c *Client) CreateFineTunedCompletion(ctx context.Context, cr *FineTunedCompletionRequest) (*FineTunedCompletionResponse, error) { +func (c *Client) CreateFineTunedCompletion(ctx context.Context, cr *CompletionRequest[models.FineTunedModel]) (*CompletionResponse[models.FineTunedModel], error) { var b, err = c.post(ctx, routes.Completions, cr) if err != nil { return nil, err } - var resp = &FineTunedCompletionResponse{} + var resp = &CompletionResponse[models.FineTunedModel]{} if err = json.Unmarshal(b, resp); err != nil { return nil, err } From 91d5896f69acda3e970f783bbb34249dc1e6016c Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 21:05:11 -0800 Subject: [PATCH 29/30] . 
--- client.go | 2 +- client_test.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index ad30be6..b7fdf3b 100644 --- a/client.go +++ b/client.go @@ -20,7 +20,7 @@ const ( basePath = "v1" ) -// Client is OpenAI GPT-3 API client. +// Client is OpenAI API client. type Client struct { token string orgID *string diff --git a/client_test.go b/client_test.go index 0f163bc..0ac056e 100644 --- a/client_test.go +++ b/client_test.go @@ -372,7 +372,7 @@ func OpenAITestServer() *httptest.Server { return case "/v1/images/generations": handleImageEndpoint(w, r) - // TODO: implement the other endpoints + // TODO: Implement the other endpoints. default: // the endpoint doesn't exist http.Error(w, "the resource path doesn't exist", http.StatusNotFound) From 1fa3e8cd5f2266cbbae31d0b47df1a21a643febf Mon Sep 17 00:00:00 2001 From: Andy Day Date: Wed, 18 Jan 2023 21:07:37 -0800 Subject: [PATCH 30/30] . --- .golangci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.golangci.yml b/.golangci.yml index ea4ba05..ee39860 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -188,7 +188,7 @@ linters: - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with f at the end - gosec # Inspects source code for security problems - - lll # Reports long lines + # - lll # Reports long lines - makezero # Finds slice declarations with non-zero initial length # - nakedret # Finds naked returns in functions greater than a specified function length - nestif # Reports deeply nested if statements