Skip to content

Commit c0e6f0b

Browse files
committed
Add docker protocol support for llama-server model loading
To pull and run models via: llama-server -d ai/smollm2:135M-Q4_K_M Signed-off-by: Eric Curtin <[email protected]>
1 parent 4f63cd7 commit c0e6f0b

File tree

2 files changed

+170
-52
lines changed

2 files changed

+170
-52
lines changed

common/arg.cpp

Lines changed: 169 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,12 @@ static bool curl_perform_with_retry(const std::string & url, CURL * curl, int ma
245245
}
246246

247247
// download one single file from remote URL to local path
248-
static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, bool offline) {
248+
static bool common_download_file_single(const std::string & url,
249+
const std::string & path,
250+
const std::string & bearer_token,
251+
bool offline,
252+
bool is_docker = false) {
253+
// Standard download logic for non-docker files
249254
// Check if the file already exists locally
250255
auto file_exists = std::filesystem::exists(path);
251256

@@ -256,7 +261,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
256261
std::string last_modified;
257262

258263
if (file_exists) {
259-
if (offline) {
264+
if (offline || is_docker) { // to be implemented, check modification of docker
260265
LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
261266
return true; // skip verification/downloading
262267
}
@@ -306,6 +311,7 @@ static bool common_download_file_single(const std::string & url, const std::stri
306311
// Set the URL, allow to follow http redirection
307312
curl_easy_setopt(curl.get(), CURLOPT_URL, url.c_str());
308313
curl_easy_setopt(curl.get(), CURLOPT_FOLLOWLOCATION, 1L);
314+
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
309315

310316
http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp");
311317
// Check if hf-token or bearer-token was specified
@@ -321,64 +327,67 @@ static bool common_download_file_single(const std::string & url, const std::stri
321327
curl_easy_setopt(curl.get(), CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA);
322328
#endif
323329

324-
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
325-
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
326-
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
330+
if (!is_docker) {
331+
typedef size_t (*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
332+
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
333+
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;
327334

328-
static std::regex header_regex("([^:]+): (.*)\r\n");
329-
static std::regex etag_regex("ETag", std::regex_constants::icase);
330-
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
335+
static std::regex header_regex("([^:]+): (.*)\r\n");
336+
static std::regex etag_regex("ETag", std::regex_constants::icase);
337+
static std::regex last_modified_regex("Last-Modified", std::regex_constants::icase);
331338

332-
std::string header(buffer, n_items);
333-
std::smatch match;
334-
if (std::regex_match(header, match, header_regex)) {
335-
const std::string & key = match[1];
336-
const std::string & value = match[2];
337-
if (std::regex_match(key, match, etag_regex)) {
338-
headers->etag = value;
339-
} else if (std::regex_match(key, match, last_modified_regex)) {
340-
headers->last_modified = value;
339+
std::string header(buffer, n_items);
340+
std::smatch match;
341+
if (std::regex_match(header, match, header_regex)) {
342+
const std::string & key = match[1];
343+
const std::string & value = match[2];
344+
if (std::regex_match(key, match, etag_regex)) {
345+
headers->etag = value;
346+
} else if (std::regex_match(key, match, last_modified_regex)) {
347+
headers->last_modified = value;
348+
}
341349
}
342-
}
343-
return n_items;
344-
};
350+
return n_items;
351+
};
345352

346-
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
347-
curl_easy_setopt(curl.get(), CURLOPT_NOPROGRESS, 1L); // hide head request progress
348-
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
349-
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
353+
curl_easy_setopt(curl.get(), CURLOPT_NOBODY, 1L); // will trigger the HEAD verb
354+
curl_easy_setopt(curl.get(), CURLOPT_HEADERFUNCTION, static_cast<CURLOPT_HEADERFUNCTION_PTR>(header_callback));
355+
curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers);
350356

351-
// we only allow retrying once for HEAD requests
352-
// this is for the use case of using running offline (no internet), retrying can be annoying
353-
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
354-
if (!was_perform_successful) {
355-
head_request_ok = false;
356-
}
357+
// we only allow retrying once for HEAD requests
358+
// this is for the use case of using running offline (no internet), retrying can be annoying
359+
bool was_perform_successful = curl_perform_with_retry(url, curl.get(), 1, 0, "HEAD");
360+
if (!was_perform_successful) {
361+
head_request_ok = false;
362+
}
357363

358-
long http_code = 0;
359-
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
360-
if (http_code == 200) {
361-
head_request_ok = true;
362-
} else {
363-
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
364-
head_request_ok = false;
365-
}
364+
long http_code = 0;
365+
curl_easy_getinfo(curl.get(), CURLINFO_RESPONSE_CODE, &http_code);
366+
if (http_code == 200) {
367+
head_request_ok = true;
368+
} else {
369+
LOG_WRN("%s: HEAD invalid http status code received: %ld\n", __func__, http_code);
370+
head_request_ok = false;
371+
}
366372

367-
// if head_request_ok is false, we don't have the etag or last-modified headers
368-
// we leave should_download as-is, which is true if the file does not exist
369-
if (head_request_ok) {
370-
// check if ETag or Last-Modified headers are different
371-
// if it is, we need to download the file again
372-
if (!etag.empty() && etag != headers.etag) {
373-
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(), headers.etag.c_str());
374-
should_download = true;
375-
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
376-
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, last_modified.c_str(), headers.last_modified.c_str());
377-
should_download = true;
373+
// if head_request_ok is false, we don't have the etag or last-modified headers
374+
// we leave should_download as-is, which is true if the file does not exist
375+
if (head_request_ok) {
376+
// check if ETag or Last-Modified headers are different
377+
// if it is, we need to download the file again
378+
if (!etag.empty() && etag != headers.etag) {
379+
LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, etag.c_str(),
380+
headers.etag.c_str());
381+
should_download = true;
382+
} else if (!last_modified.empty() && last_modified != headers.last_modified) {
383+
LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__,
384+
last_modified.c_str(), headers.last_modified.c_str());
385+
should_download = true;
386+
}
378387
}
379388
}
380389

381-
if (should_download) {
390+
if (should_download || is_docker) {
382391
std::string path_temporary = path + ".downloadInProgress";
383392
if (file_exists) {
384393
LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str());
@@ -712,7 +721,11 @@ bool common_has_curl() {
712721
return false;
713722
}
714723

715-
static bool common_download_file_single(const std::string &, const std::string &, const std::string &, bool) {
724+
static bool common_download_file_single(const std::string &,
725+
const std::string &,
726+
const std::string &,
727+
bool,
728+
bool = false) {
716729
LOG_ERR("error: built without CURL, cannot download model from internet\n");
717730
return false;
718731
}
@@ -745,6 +758,101 @@ std::pair<long, std::vector<char>> common_remote_get_content(const std::string &
745758

746759
#endif // LLAMA_USE_CURL
747760

761+
//
762+
// Docker registry functions
763+
//
764+
765+
static std::string common_docker_get_token(const std::string & repo) {
766+
std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
767+
768+
common_remote_params params;
769+
auto res = common_remote_get_content(url, params);
770+
771+
if (res.first != 200) {
772+
throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
773+
}
774+
775+
std::string response_str(res.second.begin(), res.second.end());
776+
nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
777+
778+
if (!response.contains("token")) {
779+
throw std::runtime_error("Docker registry token response missing 'token' field");
780+
}
781+
782+
return response["token"].get<std::string>();
783+
}
784+
785+
static std::string common_docker_resolve_model(const std::string & docker) {
786+
// Parse ai/smollm2:135M-Q4_K_M
787+
size_t colon_pos = docker.find(':');
788+
std::string repo, tag;
789+
if (colon_pos != std::string::npos) {
790+
repo = docker.substr(0, colon_pos);
791+
tag = docker.substr(colon_pos + 1);
792+
} else {
793+
repo = docker;
794+
tag = "latest";
795+
}
796+
797+
LOG_INF("Downloading Docker Model: %s:%s\n", repo.c_str(), tag.c_str());
798+
try {
799+
std::string token = common_docker_get_token(repo); // Get authentication token
800+
801+
// Get manifest
802+
std::string manifest_url = "https://registry-1.docker.io/v2/" + repo + "/manifests/" + tag;
803+
common_remote_params manifest_params;
804+
manifest_params.headers.push_back("Authorization: Bearer " + token);
805+
manifest_params.headers.push_back(
806+
"Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json");
807+
auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
808+
if (manifest_res.first != 200) {
809+
throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
810+
}
811+
812+
std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
813+
nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
814+
std::string gguf_digest; // Find the GGUF layer
815+
if (manifest.contains("layers")) {
816+
for (const auto & layer : manifest["layers"]) {
817+
if (layer.contains("mediaType")) {
818+
std::string media_type = layer["mediaType"].get<std::string>();
819+
if (media_type == "application/vnd.docker.ai.gguf.v3" ||
820+
media_type.find("gguf") != std::string::npos) {
821+
gguf_digest = layer["digest"].get<std::string>();
822+
break;
823+
}
824+
}
825+
}
826+
}
827+
828+
if (gguf_digest.empty()) {
829+
throw std::runtime_error("No GGUF layer found in Docker manifest");
830+
}
831+
832+
// Prepare local filename
833+
std::string model_filename = repo;
834+
std::replace(model_filename.begin(), model_filename.end(), '/', '_');
835+
model_filename += "_" + tag + ".gguf";
836+
std::string local_path = fs_get_cache_file(model_filename);
837+
if (std::filesystem::exists(local_path)) { // Check if already downloaded
838+
LOG_INF("Docker Model already cached: %s\n", local_path.c_str());
839+
return local_path;
840+
}
841+
842+
// Download the blob using common_download_file_single with is_docker=true
843+
std::string blob_url = "https://registry-1.docker.io/v2/" + repo + "/blobs/" + gguf_digest;
844+
if (!common_download_file_single(blob_url, local_path, token, false, true)) {
845+
throw std::runtime_error("Failed to download Docker blob");
846+
}
847+
848+
LOG_INF("Downloaded Docker Model to: %s\n", local_path.c_str());
849+
return local_path;
850+
} catch (const std::exception & e) {
851+
LOG_ERR("Docker Model download failed: %s\n", e.what());
852+
throw;
853+
}
854+
}
855+
748856
//
749857
// utils
750858
//
@@ -795,7 +903,9 @@ static handle_model_result common_params_handle_model(
795903
handle_model_result result;
796904
// handle pre-fill default model path and url based on hf_repo and hf_file
797905
{
798-
if (!model.hf_repo.empty()) {
906+
if (!model.docker.empty()) { // Handle Docker URLs by resolving them to local paths
907+
model.path = common_docker_resolve_model(model.docker);
908+
} else if (!model.hf_repo.empty()) {
799909
// short-hand to avoid specifying --hf-file -> default it to --model
800910
if (model.hf_file.empty()) {
801911
if (model.path.empty()) {
@@ -2636,6 +2746,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
26362746
params.model.url = value;
26372747
}
26382748
).set_env("LLAMA_ARG_MODEL_URL"));
2749+
add_opt(common_arg(
2750+
{ "-d", "-dr", "--docker", "--docker-repo" }, "<repo>/<model>[:quant]",
2751+
"Docker Hub model repository; quant is optional, default to latest.\n"
2752+
"example: ai/smollm2:135M-Q4_K_M\n"
2753+
"(default: unused)",
2754+
[](common_params & params, const std::string & value) { params.model.docker = value; })
2755+
.set_env("LLAMA_ARG_DOCKER"));
26392756
add_opt(common_arg(
26402757
{"-hf", "-hfr", "--hf-repo"}, "<user>/<model>[:quant]",
26412758
"Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n"

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ struct common_params_model {
197197
std::string url = ""; // model url to download // NOLINT
198198
std::string hf_repo = ""; // HF repo // NOLINT
199199
std::string hf_file = ""; // HF file // NOLINT
200+
std::string docker = ""; // Docker Model url to download // NOLINT
200201
};
201202

202203
struct common_params_speculative {

0 commit comments

Comments
 (0)