Skip to content

Commit 102102c

Browse files
committed
sam : image + prompt encoder, store embeddings
1 parent f6365c0 commit 102102c

File tree

9 files changed

+9894
-20
lines changed

9 files changed

+9894
-20
lines changed

examples/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ add_subdirectory(dolly-v2)
2727
add_subdirectory(replit)
2828
add_subdirectory(mpt)
2929
add_subdirectory(starcoder)
30+
add_subdirectory(sam)

examples/common.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,3 +755,46 @@ float similarity(const std::string & s0, const std::string & s1) {
755755

756756
return 1.0f - (dist / std::max(s0.size(), s1.size()));
757757
}
758+
759+
bool sam_params_parse(int argc, char ** argv, sam_params & params) {
760+
for (int i = 1; i < argc; i++) {
761+
std::string arg = argv[i];
762+
763+
if (arg == "-s" || arg == "--seed") {
764+
params.seed = std::stoi(argv[++i]);
765+
} else if (arg == "-t" || arg == "--threads") {
766+
params.n_threads = std::stoi(argv[++i]);
767+
} else if (arg == "-m" || arg == "--model") {
768+
params.model = argv[++i];
769+
} else if (arg == "-i" || arg == "--inp") {
770+
params.fname_inp = argv[++i];
771+
} else if (arg == "-o" || arg == "--out") {
772+
params.fname_out = argv[++i];
773+
} else if (arg == "-h" || arg == "--help") {
774+
sam_print_usage(argc, argv, params);
775+
exit(0);
776+
} else {
777+
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
778+
sam_print_usage(argc, argv, params);
779+
exit(0);
780+
}
781+
}
782+
783+
return true;
784+
}
785+
786+
// Print the command-line help for the SAM example to stderr.
// `params` supplies the default values displayed next to each option.
void sam_print_usage(int argc, char ** argv, const sam_params & params) {
    (void) argc; // unused; kept for the conventional (argc, argv) signature
    FILE * out = stderr;

    fprintf(out, "usage: %s [options]\n", argv[0]);
    fputs("\n", out);
    fputs("options:\n", out);
    fputs("  -h, --help              show this help message and exit\n", out);
    fputs("  -s SEED, --seed SEED    RNG seed (default: -1)\n", out);
    fprintf(out, "  -t N, --threads N       number of threads to use during computation (default: %d)\n", params.n_threads);
    fputs("  -m FNAME, --model FNAME\n", out);
    fprintf(out, "                          model path (default: %s)\n", params.model.c_str());
    fputs("  -i FNAME, --inp FNAME\n", out);
    fprintf(out, "                          input file (default: %s)\n", params.fname_inp.c_str());
    fputs("  -o FNAME, --out FNAME\n", out);
    fprintf(out, "                          output file (default: %s)\n", params.fname_out.c_str());
    fputs("\n", out);
}

examples/common.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#define COMMON_SAMPLE_RATE 16000
1212

1313
//
14-
// CLI argument parsing
14+
// GPT CLI argument parsing
1515
//
1616

1717
struct gpt_params {
@@ -155,3 +155,20 @@ bool vad_simple(
155155

156156
// compute similarity between two strings using Levenshtein distance
157157
float similarity(const std::string & s0, const std::string & s1);
158+
159+
//
// SAM argument parsing
//

// Runtime options for the SAM example, populated by sam_params_parse().
struct sam_params {
    int32_t seed = -1; // RNG seed
    // worker thread count: min(4, hardware concurrency)
    int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());

    std::string model = "models/sam-vit-b/ggml-model-f16.bin"; // model path
    std::string fname_inp = "img.jpg"; // input image file
    std::string fname_out = "img.out"; // output file
};

// Parse command-line arguments into `params`; prints usage and exits on -h or on error.
bool sam_params_parse(int argc, char ** argv, sam_params & params);

// Print usage/help text to stderr, showing the defaults carried in `params`.
void sam_print_usage(int argc, char ** argv, const sam_params & params);

examples/sam/CMakeLists.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
#
# sam
#

add_executable(sam main.cpp)
target_link_libraries(sam PRIVATE ggml common)

#
# sam-quantize (disabled — enable once quantize.cpp exists)
#

#add_executable(sam-quantize quantize.cpp)
#target_link_libraries(sam-quantize PRIVATE ggml common)
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Convert a SAM model checkpoint to a ggml compatible file
#
# usage: convert-pth-to-ggml.py file-model ftype
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# Only the image encoder and prompt encoder tensors are exported for now
# (the mask branch of the prompt encoder is skipped).

import os
import sys
import code
import json
import torch
import struct
import numpy as np

if len(sys.argv) < 3:
    print("Usage: convert-pth-to-ggml.py file-model ftype\n")
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)

# output in the same directory as the model
fname_model = sys.argv[1]
fname_out = os.path.dirname(fname_model) + "/ggml-model.bin"

# possible data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

# the guard above guarantees argv[2] is present
ftype = int(sys.argv[2])
if ftype < 0 or ftype > 1:
    print("Invalid ftype: " + str(ftype))
    sys.exit(1)

# encode the requested type into the output name, e.g. ggml-model-f16.bin
fname_out = fname_out.replace(".bin", "-" + ftype_str[ftype] + ".bin")

model = torch.load(fname_model, map_location="cpu")

# TODO: determine based on model data
# TODO: add decoder / prompt encoder if needed
hparams = {
    "n_enc_state": 768,
    "n_enc_layers": 12,
    "n_enc_heads": 12,
    "n_enc_out_chans": 256,

    "n_pt_embd": 4,
}

print(hparams)

for k, v in model.items():
    print(k, v.shape)

#exit()
#code.interact(local=locals())

fout = open(fname_out, "wb")

# file header: magic, hyperparameters, ftype
fout.write(struct.pack("i", 0x67676d6c)) # magic: ggml in hex
fout.write(struct.pack("i", hparams["n_enc_state"]))
fout.write(struct.pack("i", hparams["n_enc_layers"]))
fout.write(struct.pack("i", hparams["n_enc_heads"]))
fout.write(struct.pack("i", hparams["n_enc_out_chans"]))
fout.write(struct.pack("i", hparams["n_pt_embd"]))
fout.write(struct.pack("i", ftype))

for name, tensor in model.items():
    # TODO: export only the Encoder -- after it works we will export the other stuff
    if not (name.startswith("image_encoder") or name.startswith("prompt_encoder")):
        continue

    # the mask branch of the prompt encoder is not exported yet
    if name.startswith("prompt_encoder.mask"):
        continue

    print("Processing variable: " + name + " with shape: ", tensor.shape, " and type: ", tensor.dtype)

    data = tensor.numpy()
    n_dims = len(data.shape)

    # default type is fp16; keep 1D tensors, the positional embedding and all
    # prompt-encoder tensors in fp32
    ftype_cur = 1
    if ftype == 0 or n_dims == 1 or \
            name == "image_encoder.pos_embed" or \
            name.startswith("prompt_encoder"):
        print(" Converting to float32")
        data = data.astype(np.float32)
        ftype_cur = 0
    else:
        print(" Converting to float16")
        data = data.astype(np.float16)

    # reshape the 1D bias into a 4D tensor so we can use ggml_repeat
    # keep it in F32 since the data is small
    if name == "image_encoder.patch_embed.proj.bias":
        data = data.reshape(1, data.shape[0], 1, 1)
        n_dims = len(data.shape)

    print(" New shape: ", data.shape)

    # tensor header: n_dims, name length, ftype, then dims (reversed), then name
    name_bytes = name.encode('utf-8')  # renamed from `str` to avoid shadowing the builtin
    fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(name_bytes)

    # tensor data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")

0 commit comments

Comments
 (0)