@@ -1419,7 +1419,7 @@ struct server_context {
1419
1419
queue_results.send (res);
1420
1420
}
1421
1421
1422
- void send_rank (const server_slot & slot, const llama_batch & batch) {
1422
+ void send_rerank (const server_slot & slot, const llama_batch & batch) {
1423
1423
server_task_result res;
1424
1424
res.id = slot.id_task ;
1425
1425
res.error = false ;
@@ -1440,19 +1440,19 @@ struct server_context {
1440
1440
1441
1441
res.data = json {
1442
1442
{" index" , slot.index },
1443
- {" rank " , -1e6 },
1443
+ {" score " , -1e6 },
1444
1444
};
1445
1445
1446
1446
continue ;
1447
1447
}
1448
1448
1449
1449
res.data = json {
1450
1450
{" index" , slot.index },
1451
- {" rank " , embd[0 ]},
1451
+ {" score " , embd[0 ]},
1452
1452
};
1453
1453
}
1454
1454
1455
- SLT_DBG (slot, " sending rank , res = '%s'\n " , res.data .dump ().c_str ());
1455
+ SLT_DBG (slot, " sending rerank result , res = '%s'\n " , res.data .dump ().c_str ());
1456
1456
1457
1457
queue_results.send (res);
1458
1458
}
@@ -1493,6 +1493,9 @@ struct server_context {
1493
1493
else if (prompt.is_array ()) {
1494
1494
std::vector<json> prompts = prompt;
1495
1495
if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
1496
+ // prompts[0] is the question
1497
+ // the rest are the answers/documents
1498
+ SRV_DBG (" creating rerank tasks, n_prompts = %d\n " , (int ) prompts.size () - 1 );
1496
1499
for (size_t i = 1 ; i < prompts.size (); i++) {
1497
1500
json qd;
1498
1501
qd.push_back (prompts[0 ]);
@@ -1501,6 +1504,7 @@ struct server_context {
1501
1504
create_task (data, true , qd);
1502
1505
}
1503
1506
} else {
1507
+ SRV_DBG (" creating multi-prompt tasks, n_prompts = %d\n " , (int ) prompts.size ());
1504
1508
for (size_t i = 0 ; i < prompts.size (); i++) {
1505
1509
const auto & e = prompts[i];
1506
1510
if (e.is_string () || json_is_array_of_numbers (e)) {
@@ -1965,6 +1969,7 @@ struct server_context {
1965
1969
// track if this is an embedding or non-embedding batch
1966
1970
// if we've added sampled tokens above, we are in non-embedding mode
1967
1971
// -1: none, 0: non-embedding, 1: embedding
1972
+ // TODO: make enum
1968
1973
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1 ;
1969
1974
1970
1975
// next, batch any pending prompts without exceeding n_batch
@@ -2133,6 +2138,7 @@ struct server_context {
2133
2138
slot.n_prompt_tokens_processed = 0 ;
2134
2139
}
2135
2140
2141
+ // non-causal tasks require to fit the entire prompt in the physical batch
2136
2142
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2137
2143
// cannot fit the prompt in the current batch - will try next iter
2138
2144
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
@@ -2318,7 +2324,7 @@ struct server_context {
2318
2324
}
2319
2325
2320
2326
if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2321
- send_rank (slot, batch_view);
2327
+ send_rerank (slot, batch_view);
2322
2328
slot.release ();
2323
2329
slot.i_batch = -1 ;
2324
2330
continue ; // continue loop of slots
0 commit comments