diff --git a/include/ollama.hpp b/include/ollama.hpp
index 73f0e90..fc540ed 100644
--- a/include/ollama.hpp
+++ b/include/ollama.hpp
@@ -263,13 +263,14 @@ namespace ollama
         request(): json() {}
         ~request(){};
 
-        static ollama::request from_embedding(const std::string& name, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+        static ollama::request from_embedding(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate=true, const std::string& keep_alive_duration="5m")
         {
             ollama::request request(message_type::embedding);
 
-            request["model"] = name;
-            request["prompt"] = prompt;
+            request["model"] = model;
+            request["input"] = input;
             if (options!=nullptr) request["options"] = options["options"];
+            request["truncate"] = truncate;
             request["keep_alive"] = keep_alive_duration;
 
             return request;
@@ -295,7 +296,7 @@ namespace ollama
 
             if (type==message_type::generation && json_data.contains("response")) simple_string=json_data["response"].get<std::string>();
             else
-            if (type==message_type::embedding && json_data.contains("embedding")) simple_string=json_data["embedding"].get<std::string>();
+            if (type==message_type::embedding && json_data.contains("embeddings")) simple_string=json_data["embeddings"].get<std::string>();
             else
             if (type==message_type::chat && json_data.contains("message")) simple_string=json_data["message"]["content"].get<std::string>();
 
@@ -715,15 +716,15 @@ class Ollama
         return false;
     }
 
-    ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
     {
-        ollama::request request = ollama::request::from_embedding(model, prompt, options, keep_alive_duration);
+        ollama::request request = ollama::request::from_embedding(model, input, options, truncate, keep_alive_duration);
         ollama::response response;
 
         std::string request_string = request.dump();
         if (ollama::log_requests) std::cout << request_string << std::endl;
 
-        if (auto res = cli->Post("/api/embeddings", request_string, "application/json"))
+        if (auto res = cli->Post("/api/embed", request_string, "application/json"))
         {
             if (ollama::log_replies) std::cout << res->body << std::endl;
 
@@ -885,9 +886,9 @@ namespace ollama
         return ollama.push_model(model, allow_insecure);
     }
 
-    inline ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    inline ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
     {
-        return ollama.generate_embeddings(model, prompt, options, keep_alive_duration);
+        return ollama.generate_embeddings(model, input, options, truncate, keep_alive_duration);
     }
 
     inline void setReadTimeout(const int& seconds)
diff --git a/singleheader/ollama.hpp b/singleheader/ollama.hpp
index 88ac415..d1aff5d 100644
--- a/singleheader/ollama.hpp
+++ b/singleheader/ollama.hpp
@@ -35053,13 +35053,14 @@ namespace ollama
         request(): json() {}
         ~request(){};
 
-        static ollama::request from_embedding(const std::string& name, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+        static ollama::request from_embedding(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate=true, const std::string& keep_alive_duration="5m")
         {
             ollama::request request(message_type::embedding);
 
-            request["model"] = name;
-            request["prompt"] = prompt;
+            request["model"] = model;
+            request["input"] = input;
             if (options!=nullptr) request["options"] = options["options"];
+            request["truncate"] = truncate;
             request["keep_alive"] = keep_alive_duration;
 
             return request;
@@ -35085,7 +35086,7 @@ namespace ollama
 
             if (type==message_type::generation && json_data.contains("response")) simple_string=json_data["response"].get<std::string>();
             else
-            if (type==message_type::embedding && json_data.contains("embedding")) simple_string=json_data["embedding"].get<std::string>();
+            if (type==message_type::embedding && json_data.contains("embeddings")) simple_string=json_data["embeddings"].get<std::string>();
             else
             if (type==message_type::chat && json_data.contains("message")) simple_string=json_data["message"]["content"].get<std::string>();
 
@@ -35505,15 +35506,15 @@ class Ollama
         return false;
     }
 
-    ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
     {
-        ollama::request request = ollama::request::from_embedding(model, prompt, options, keep_alive_duration);
+        ollama::request request = ollama::request::from_embedding(model, input, options, truncate, keep_alive_duration);
         ollama::response response;
 
         std::string request_string = request.dump();
         if (ollama::log_requests) std::cout << request_string << std::endl;
 
-        if (auto res = cli->Post("/api/embeddings", request_string, "application/json"))
+        if (auto res = cli->Post("/api/embed", request_string, "application/json"))
         {
             if (ollama::log_replies) std::cout << res->body << std::endl;
 
@@ -35675,9 +35676,9 @@ namespace ollama
         return ollama.push_model(model, allow_insecure);
     }
 
-    inline ollama::response generate_embeddings(const std::string& model, const std::string& prompt, const json& options=nullptr, const std::string& keep_alive_duration="5m")
+    inline ollama::response generate_embeddings(const std::string& model, const std::string& input, const json& options=nullptr, bool truncate = true, const std::string& keep_alive_duration="5m")
    {
-        return ollama.generate_embeddings(model, prompt, options, keep_alive_duration);
+        return ollama.generate_embeddings(model, input, options, truncate, keep_alive_duration);
    }
 
     inline void setReadTimeout(const int& seconds)
diff --git a/test/test.cpp b/test/test.cpp
index 2e4e71f..10cf894 100644
--- a/test/test.cpp
+++ b/test/test.cpp
@@ -12,6 +12,8 @@
 // Note that this is static. We will use these options for other generations.
 static ollama::options options;
 
+static std::string test_model = "llama3:8b", image_test_model = "llava";
+
 TEST_SUITE("Ollama Tests") {
 
     TEST_CASE("Initialize Options") {
@@ -52,19 +54,19 @@ TEST_SUITE("Ollama Tests") {
 
     TEST_CASE("Load Model") {
 
-        CHECK( ollama::load_model("llama3:8b") );
+        CHECK( ollama::load_model(test_model) );
     }
 
     TEST_CASE("Pull, Copy, and Delete Models") {
 
         // Pull a model by specifying a model name.
-        CHECK( ollama::pull_model("llama3:8b") == true );
+        CHECK( ollama::pull_model(test_model) == true );
 
         // Copy a model by specifying a source model and destination model name.
-        CHECK( ollama::copy_model("llama3:8b", "llama3_copy") ==true );
+        CHECK( ollama::copy_model(test_model, test_model+"_copy") ==true );
 
         // Delete a model by specifying a model name.
-        CHECK( ollama::delete_model("llama3_copy") == true );
+        CHECK( ollama::delete_model(test_model+"_copy") == true );
     }
 
     TEST_CASE("Model Info") {
@@ -81,7 +83,7 @@ TEST_SUITE("Ollama Tests") {
         // List the models available locally in the ollama server
         std::vector<std::string> models = ollama::list_models();
 
-        bool contains_model = (std::find(models.begin(), models.end(), "llama3:8b") != models.end() );
+        bool contains_model = (std::find(models.begin(), models.end(), test_model) != models.end() );
 
         CHECK( contains_model );
     }
@@ -101,12 +103,9 @@ TEST_SUITE("Ollama Tests") {
 
     TEST_CASE("Basic Generation") {
 
-        ollama::response response = ollama::generate("llama3:8b", "Why is the sky blue?", options);
-        //std::cout << response << std::endl;
-
-        std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
+        ollama::response response = ollama::generate(test_model, "Why is the sky blue?", options);
 
-        CHECK(response.as_simple_string() == expected_response);
+        CHECK( response.as_json().contains("response") == true );
     }
 
 
@@ -124,11 +123,11 @@ TEST_SUITE("Ollama Tests") {
     TEST_CASE("Streaming Generation") {
 
         std::function<void(const ollama::response&)> response_callback = on_receive_response;
-        ollama::generate("llama3:8b", "Why is the sky blue?", response_callback, options);
+        ollama::generate(test_model, "Why is the sky blue?", response_callback, options);
 
         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
 
-        CHECK( streamed_response == expected_response );
+        CHECK( streamed_response != "" );
     }
 
     TEST_CASE("Non-Singleton Generation") {
@@ -136,23 +135,22 @@ TEST_SUITE("Ollama Tests") {
 
         Ollama my_ollama_server("http://localhost:11434");
 
         // You can use all of the same functions from this instanced version of the class.
-        ollama::response response = my_ollama_server.generate("llama3:8b", "Why is the sky blue?", options);
-        //std::cout << response << std::endl;
+        ollama::response response = my_ollama_server.generate(test_model, "Why is the sky blue?", options);
 
         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
 
-        CHECK(response.as_simple_string() == expected_response);
+        CHECK(response.as_json().contains("response") == true);
     }
 
     TEST_CASE("Single-Message Chat") {
 
         ollama::message message("user", "Why is the sky blue?");
-        ollama::response response = ollama::chat("llama3:8b", message, options);
+        ollama::response response = ollama::chat(test_model, message, options);
 
         std::string expected_response = "What a great question!\n\nThe sky appears blue because of a phenomenon called Rayleigh scattering,";
 
-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }
 
     TEST_CASE("Multi-Message Chat") {
@@ -163,11 +161,11 @@ TEST_SUITE("Ollama Tests") {
 
         ollama::messages messages = {message1, message2, message3};
 
-        ollama::response response = ollama::chat("llama3:8b", messages, options);
+        ollama::response response = ollama::chat(test_model, messages, options);
 
         std::string expected_response = "";
 
-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }
 
     TEST_CASE("Chat with Streaming Response") {
@@ -182,7 +180,7 @@ TEST_SUITE("Ollama Tests") {
 
         ollama::message message("user", "Why is the sky blue?");
 
-        ollama::chat("llama3:8b", message, response_callback, options);
+        ollama::chat(test_model, message, response_callback, options);
 
         CHECK(streamed_response!="");
     }
@@ -195,12 +193,9 @@ TEST_SUITE("Ollama Tests") {
 
         ollama::image image = ollama::image::from_file("llama.jpg");
 
-        //ollama::images images={image};
-
-        ollama::response response = ollama::generate("llava", "What do you see in this image?", options, image);
-        std::string expected_response = " The image features a large, fluffy white llama";
+        ollama::response response = ollama::generate(image_test_model, "What do you see in this image?", options, image);
 
-        CHECK(response.as_simple_string() == expected_response);
+        CHECK( response.as_json().contains("response") == true );
     }
 
     TEST_CASE("Generation with Multiple Images") {
@@ -214,10 +209,10 @@ TEST_SUITE("Ollama Tests") {
 
         ollama::images images={image, base64_image};
 
-        ollama::response response = ollama::generate("llava", "What do you see in this image?", options, images);
+        ollama::response response = ollama::generate(image_test_model, "What do you see in this image?", options, images);
         std::string expected_response = " The image features a large, fluffy white and gray llama";
 
-        CHECK(response.as_simple_string() == expected_response);
+        CHECK(response.as_json().contains("response") == true);
     }
 
     TEST_CASE("Chat with Image") {
@@ -230,21 +225,20 @@ TEST_SUITE("Ollama Tests") {
 
         // We can optionally include images with each message. Vision-enabled models will be able to utilize these.
         ollama::message message_with_image("user", "What do you see in this image?", image);
-        ollama::response response = ollama::chat("llava", message_with_image, options);
+        ollama::response response = ollama::chat(image_test_model, message_with_image, options);
 
         std::string expected_response = " The image features a large, fluffy white llama";
 
-        CHECK(response.as_simple_string()!="");
+        CHECK(response.as_json().contains("message") == true);
     }
 
     TEST_CASE("Embedding Generation") {
 
         options["num_predict"] = 18;
 
-        ollama::response response = ollama::generate_embeddings("llama3:8b", "Why is the sky blue?");
-        //std::cout << response << std::endl;
+        ollama::response response = ollama::generate_embeddings(test_model, "Why is the sky blue?");
 
-        CHECK(response.as_json().contains("embedding") == true);
+        CHECK(response.as_json().contains("embeddings") == true);
     }
 
     TEST_CASE("Enable Debug Logging") {
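Note (not part of the patch): a minimal usage sketch of the embedding call after this change, assuming an Ollama server on the default localhost:11434 and the same "llama3:8b" model used by the tests; the printed dimension count is purely illustrative.

#include "ollama.hpp"
#include <iostream>

int main()
{
    // Updated signature: generate_embeddings(model, input, options, truncate, keep_alive_duration).
    // Passing nullptr leaves options unset; truncate=true asks the server to clip inputs that
    // exceed the model's context length instead of returning an error.
    ollama::response response =
        ollama::generate_embeddings("llama3:8b", "Why is the sky blue?", nullptr, true, "5m");

    // /api/embed returns an "embeddings" array (one vector per input), unlike the old
    // /api/embeddings route, which returned a single "embedding" field.
    if (response.as_json().contains("embeddings"))
        std::cout << response.as_json()["embeddings"][0].size() << " dimensions" << std::endl;

    return 0;
}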