9 | 9 | #include <algorithm>
10 | 10 | #include <sstream>
11 | 11 |
12 | | -llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) { |
13 | | - // clear empty sequences |
14 | | - // the previous ubatch is assumed to be gone, |
15 | | - // so nothing should refer to values in these sequences anymore. |
16 | | - for (size_t i = seq.size(); i-- > 0;) { |
17 | | - if (seq[i].length == 0) { |
18 | | - seq.pop_back(); |
19 | | - } else { |
20 | | - break; |
21 | | - } |
22 | | - } |
23 | | - |
24 | | - udatas.push_back({}); |
25 | | - |
26 | | - auto & udata = udatas.back(); |
27 | | - |
28 | | - udata.token.resize(!has_embd ? n_ubatch : 0); |
29 | | - udata.embd.resize(has_embd ? n_embd * n_ubatch : 0); |
30 | | - udata.pos.resize(n_ubatch); |
31 | | - udata.n_seq_id.resize(n_ubatch); |
32 | | - udata.seq_id.resize(n_ubatch); |
33 | | - udata.output.resize(n_ubatch); |
34 | | - |
35 | | - llama_ubatch ubatch = { |
36 | | - /*equal_seqs =*/ true, |
37 | | - /*n_tokens =*/ 0, |
38 | | - /*n_seq_tokens =*/ 0, |
39 | | - /*n_seqs =*/ 0, |
40 | | - /*token =*/ !has_embd ? udata.token.data() : nullptr, |
41 | | - /*embd =*/ has_embd ? udata.embd.data() : nullptr, |
42 | | - /*pos =*/ udata.pos.data(), |
43 | | - /*n_seq_id =*/ udata.n_seq_id.data(), |
44 | | - /*seq_id =*/ udata.seq_id.data(), |
45 | | - /*output =*/ udata.output.data(), |
46 | | - }; |
47 | | - |
48 | | - return ubatch; |
49 | | -} |
50 | | - |
51 | | -void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) { |
52 | | - GGML_ASSERT(batch != nullptr); |
53 | | - GGML_ASSERT(length <= seq.length); |
54 | | - // Can only add sequences of equal lengths to a batch, |
55 | | - // otherwise it isn't clear to which sequence a token belongs |
56 | | - GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs); |
57 | | - GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs); |
58 | | - // NOTE: loops are separated for cache-friendliness |
59 | | - if (batch->token) { |
60 | | - if (ubatch.equal_seqs) { |
61 | | - for (size_t i = 0; i < length; ++i) { |
62 | | - ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]]; |
63 | | - } |
64 | | - } else { |
65 | | - // simple split |
66 | | - ubatch.token = batch->token + seq.offset; |
67 | | - } |
68 | | - } else { |
69 | | - ubatch.token = nullptr; |
70 | | - } |
71 | | - if (batch->embd) { |
72 | | - if (ubatch.equal_seqs) { |
73 | | - for (size_t i = 0; i < length; ++i) { |
74 | | - memcpy( |
75 | | - ubatch.embd + (n_embd * (ubatch.n_tokens + i)), |
76 | | - batch->embd + (n_embd * ids[seq.offset + i]), |
77 | | - n_embd * sizeof(float) |
78 | | - ); |
79 | | - } |
80 | | - } else { |
81 | | - // simple split |
82 | | - ubatch.embd = batch->embd + (n_embd * seq.offset); |
83 | | - } |
84 | | - } else { |
85 | | - ubatch.embd = nullptr; |
86 | | - } |
87 | | - if (ubatch.equal_seqs) { |
88 | | - for (size_t i = 0; i < length; ++i) { |
89 | | - ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; |
90 | | - } |
91 | | - } else { |
92 | | - // simple split |
93 | | - ubatch.pos = batch->pos + seq.offset; |
94 | | - } |
95 | | - if (ubatch.equal_seqs) { |
96 | | - ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; |
97 | | - if (seq.seq_id) { |
98 | | - ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; |
99 | | - } |
100 | | - } else { |
101 | | - // simple split |
102 | | - if (batch->n_seq_id) { |
103 | | - ubatch.n_seq_id = batch->n_seq_id + seq.offset; |
104 | | - } else { |
105 | | - for (size_t i = 0; i < length; ++i) { |
106 | | - ubatch.n_seq_id[ubatch.n_seqs + i] = 1; |
107 | | - } |
108 | | - } |
109 | | - if (batch->seq_id) { |
110 | | - ubatch.seq_id = batch->seq_id + seq.offset; |
111 | | - } |
112 | | - } |
113 | | - if (batch->logits) { |
114 | | - if (ubatch.equal_seqs) { |
115 | | - for (size_t i = 0; i < length; ++i) { |
116 | | - size_t id = ids[seq.offset + i]; |
117 | | - int8_t is_output = batch->logits[id]; |
118 | | - ubatch.output[ubatch.n_tokens + i] = is_output; |
119 | | - if (is_output) { out_ids.push_back(id); } |
120 | | - } |
121 | | - } else { |
122 | | - // simple split |
123 | | - ubatch.output = batch->logits + seq.offset; |
124 | | - for (size_t i = 0; i < length; ++i) { |
125 | | - if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); } |
126 | | - } |
127 | | - } |
128 | | - } else { |
129 | | - // only get last output |
130 | | - for (size_t i = 0; i < length; ++i) { |
131 | | - size_t id = ids[seq.offset + i]; |
132 | | - int8_t is_last = id == ids.size() - 1; |
133 | | - ubatch.output[ubatch.n_tokens + i] = is_last; |
134 | | - if (is_last) { out_ids.push_back(id); } |
135 | | - } |
136 | | - } |
137 | | - if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) { |
138 | | - ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1; |
139 | | - } |
140 | | - ubatch.n_tokens += length; |
141 | | - ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits |
142 | | - seq.offset += length; |
143 | | - seq.length -= length; |
144 | | - n_tokens -= length; |
145 | | - GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs); |
146 | | -} |
147 | | - |
148 | | -llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) { |
149 | | - n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; |
150 | | - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); |
151 | | - ubatch.equal_seqs = false; |
152 | | - if (!seq.empty()) { |
153 | | - llama_sbatch_seq & s = seq[0]; |
154 | | - size_t length = s.length < n_ubatch ? s.length : n_ubatch; |
155 | | - GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits |
156 | | - add_seq_to_ubatch(ubatch, s, length); |
157 | | - } |
158 | | - return ubatch; |
159 | | -} |
160 | | - |
161 | | -llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) { |
162 | | - n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; |
163 | | - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); |
164 | | - if (!seq.empty()) { |
165 | | - size_t length = 0; |
166 | | - size_t n_tokens_in_ubatch = 0; |
167 | | - GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits |
168 | | - // smallest first, because it's easier to split this way; |
169 | | - // starting from the end to pop in constant time. |
170 | | - for (size_t i = seq.size(); i-- > 0;) { |
171 | | - llama_sbatch_seq & s = seq[i]; |
172 | | - GGML_ASSERT(s.length > 0); |
173 | | - if (length == 0) { |
174 | | - length = s.length < n_ubatch ? s.length : n_ubatch; |
175 | | - } |
176 | | - add_seq_to_ubatch(ubatch, s, length); |
177 | | - n_tokens_in_ubatch += length; |
178 | | - // shared prompts can't be mixed with any of their sequences, |
179 | | - // so it's safer to compute them in their own ubatch |
180 | | - if (s.n_seq_id > 1) { break; } |
181 | | - // stop when there isn't enough space for another sequence |
182 | | - if (length + n_tokens_in_ubatch > n_ubatch) { break; } |
183 | | - } |
184 | | - } |
185 | | - return ubatch; |
186 | | -} |
187 | | - |
188 | | -llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { |
189 | | - n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch; |
190 | | - llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr); |
191 | | - if (!seq.empty()) { |
192 | | - llama_sbatch_seq & s = seq[seq.size() - 1]; |
193 | | - size_t length = s.length < n_ubatch ? s.length : n_ubatch; |
194 | | - GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits |
195 | | - add_seq_to_ubatch(ubatch, s, length); |
196 | | - } |
197 | | - return ubatch; |
198 | | -} |
199 | | - |
200 | | -llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) { |
201 | | - GGML_ASSERT(batch.n_tokens >= 0); |
202 | | - this->batch = &batch; |
203 | | - this->n_embd = n_embd; |
204 | | - |
205 | | - n_tokens = batch.n_tokens; |
206 | | - ids.resize(n_tokens); |
207 | | - out_ids.clear(); |
208 | | - // TODO: reserve out_ids and seq |
209 | | - |
210 | | - for (size_t i = 0; i < n_tokens; ++i) { |
211 | | - ids[i] = i; |
212 | | - } |
213 | | - |
214 | | - if (simple_split) { |
215 | | - seq.resize(1); |
216 | | - llama_sbatch_seq & s = seq[0]; |
217 | | - s.n_seq_id = 0; |
218 | | - s.seq_id = nullptr; |
219 | | - s.offset = 0; |
220 | | - s.length = n_tokens; |
221 | | - return; |
222 | | - } |
223 | | - |
224 | | - std::sort(ids.begin(), ids.end(), |
225 | | - [&batch](size_t a, size_t b) { |
226 | | - int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; |
227 | | - int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1; |
228 | | - // sort by seq_id, then by pos |
229 | | - if (n_seq_a == n_seq_b) { |
230 | | - if (batch.seq_id) { |
231 | | - for (int32_t i = 0; i < n_seq_a; ++i) { |
232 | | - llama_seq_id seq_id_a = batch.seq_id[a][i]; |
233 | | - llama_seq_id seq_id_b = batch.seq_id[b][i]; |
234 | | - // smaller seq_ids go first |
235 | | - if (seq_id_a != seq_id_b) { |
236 | | - return seq_id_a < seq_id_b; |
237 | | - } |
238 | | - } |
239 | | - } |
240 | | - // when all else is equal, sort by pos |
241 | | - if (batch.pos) { |
242 | | - return batch.pos[a] < batch.pos[b]; |
243 | | - } |
244 | | - // no pos, sort by id |
245 | | - return a < b; |
246 | | - } |
247 | | - // shared prompts go first |
248 | | - return n_seq_a > n_seq_b; |
249 | | - } |
250 | | - ); |
251 | | - |
252 | | - // init seq |
253 | | - llama_sbatch_seq * last_seq = nullptr; |
254 | | - |
255 | | - for (size_t i = 0; i < n_tokens; ++i) { |
256 | | - const size_t bi = ids[i]; |
257 | | - const int32_t n_seqs = batch.n_seq_id[bi]; |
258 | | - llama_seq_id * seq_ids = batch.seq_id[bi]; |
259 | | - if (last_seq != nullptr) { |
260 | | - bool same = n_seqs == last_seq->n_seq_id; |
261 | | - for (int32_t j = 0; same && j < n_seqs; ++j) { |
262 | | - if (seq_ids[j] != last_seq->seq_id[j]) { |
263 | | - same = false; |
264 | | - } |
265 | | - } |
266 | | - if (same) { |
267 | | - last_seq->length += 1; |
268 | | - continue; |
269 | | - } |
270 | | - } |
271 | | - llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1}; |
272 | | - seq.push_back(new_seq); |
273 | | - last_seq = &seq.back(); |
274 | | - } |
275 | | - |
276 | | - // keep shared prompts first at the end, then sort by length descending. |
277 | | - std::sort(seq.begin(), seq.end(), |
278 | | - [](llama_sbatch_seq & a, llama_sbatch_seq & b) { |
279 | | - if (a.n_seq_id == b.n_seq_id) { |
280 | | - return a.length > b.length; |
281 | | - } |
282 | | - return a.n_seq_id < b.n_seq_id; |
283 | | - } |
284 | | - ); |
285 | | -} |
286 | | - |
287 | 12 | llama_batch_allocr::llama_batch_allocr() { |
288 | 13 | const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG"); |
289 | 14 | debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0; |
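For reference, since this commit deletes the splitting logic wholesale: the removed `llama_sbatch` exposed three strategies. `split_simple` returns one contiguous slice with no reordering, `split_equal` takes equal-length runs from each sequence (popping smallest-first from the back of `seq`), and `split_seq` drains one sequence at a time. The sketch below is a minimal, self-contained illustration of the `split_equal` walk only; `Seq`, the function signature, and the return type are simplified stand-ins for illustration, not the actual llama.cpp types.

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// simplified stand-in for llama_sbatch_seq
struct Seq {
    int    n_seq_id; // > 1 means the tokens are shared by several sequences (shared prompt)
    size_t offset;   // first not-yet-consumed token of this run
    size_t length;   // remaining tokens in this run
};

// Walk the runs from the back (smallest remaining first) and take the same
// number of tokens from each, so every sequence in the micro-batch has equal
// length. Returns (run index, tokens taken) pairs.
static std::vector<std::pair<size_t, size_t>> split_equal(std::vector<Seq> & seqs, size_t n_ubatch) {
    std::vector<std::pair<size_t, size_t>> taken;
    size_t length   = 0; // common per-sequence token count for this micro-batch
    size_t n_tokens = 0;
    for (size_t i = seqs.size(); i-- > 0;) {
        Seq & s = seqs[i];
        if (s.length == 0) {
            continue; // already fully consumed
        }
        if (length == 0) {
            // the smallest remaining run fixes the common length
            length = std::min(s.length, n_ubatch);
        }
        taken.push_back({i, length});
        s.offset += length;
        s.length -= length;
        n_tokens += length;
        // shared prompts are computed in their own micro-batch
        if (s.n_seq_id > 1) { break; }
        // stop when another run of the same length would not fit
        if (n_tokens + length > n_ubatch) { break; }
    }
    return taken;
}

int main() {
    // ordered the way the removed constructor leaves `seq`: shared prompts
    // last (ascending n_seq_id), then by length descending, so the smallest
    // single-sequence run sits at the back
    std::vector<Seq> seqs = {
        {1,  0, 8},
        {1,  8, 5},
        {1, 13, 5},
    };
    for (const auto & [i, n] : split_equal(seqs, 16)) {
        std::printf("seq %zu: take %zu tokens\n", i, n);
    }
    return 0;
}
```

Iterating from the back is what makes the constructor's descending-length sort pay off: the smallest remaining run fixes the common length, every run visited after it is at least that long, and fully consumed runs can then be popped from the tail in constant time, as the removed `reserve_ubatch` did.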