I would like to be able to decode a sequence of token ids incrementally, in a decoder-agnostic manner. I haven't found a straightforward way to do this with the current API - the first token of a sequence is treated differently by some decoders, which means that in general
decode([1,2,3]) != decode([1]) + decode([2]) + decode([3])
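For example, with a tokenizer that uses a Metaspace decoder (a minimal sketch; the file path is illustrative):

use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let ids = tokenizer.encode("Hello world", false)?.get_ids().to_vec();

    let whole = tokenizer.decode(ids.clone(), false)?;
    let piecewise = ids
        .iter()
        .map(|&id| tokenizer.decode(vec![id], false))
        .collect::<tokenizers::Result<Vec<_>>>()?
        .join("");

    // Each token is decoded as if it were the first in a sequence, so the
    // Metaspace decoder strips the leading "▁" from every piece and the
    // spaces between words are lost ("Hello world" vs "Helloworld").
    assert_ne!(whole, piecewise);
    Ok(())
}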
It would be really nice to have some kind of "continuation" flag to indicate that the result is intended to be appended to an already-decoded prefix, so that you could have
decode([1,2,3]) == decode([1]) + decode'([2]) + decode'([3])
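For illustration, if such a continuation mode existed (decode_continuation below is an invented name, not a real method), streaming usage could look like:

// Hypothetical API sketch - decode_continuation is invented for illustration.
let first = tokenizer.decode(vec![1], false)?;
let rest = tokenizer.decode_continuation(vec![2, 3])?;
assert_eq!(first + &rest, tokenizer.decode(vec![1, 2, 3], false)?);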
It would also be nice to have a variant of this that takes a single u32 id or string token rather than a vec, for related reasons (the latter could be used with id_to_token).
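For example (decode_continuation_id and decode_continuation_token are invented names; only id_to_token exists in the current API):

// Hypothetical single-id / single-token variants, e.g. in a generation loop:
for id in generated_ids {
    let piece = tokenizer.decode_continuation_id(id)?; // u32 variant
    // ...or via the existing id_to_token:
    // let token = tokenizer.id_to_token(id).ok_or("unknown id")?;
    // let piece = tokenizer.decode_continuation_token(&token)?;
    print!("{piece}");
}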
I'd love to know if there's another way to achieve this besides my current ugly workaround :)
Current workaround
use std::fmt;

use tokenizers::DecoderWrapper::{ByteLevel, Metaspace, WordPiece, BPE};
use tokenizers::Tokenizer;

// Minimal stand-ins for the custom error types used below.
#[derive(Debug)]
pub(crate) struct DecodingError;
impl fmt::Display for DecodingError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "decoded text did not start with the expected prefix")
    }
}
impl std::error::Error for DecodingError {}

#[derive(Debug)]
pub(crate) struct UnsupportedTokenizerError;
impl fmt::Display for UnsupportedTokenizerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "unsupported tokenizer decoder")
    }
}
impl std::error::Error for UnsupportedTokenizerError {}

pub(crate) struct Decoder {
    pub(crate) tokenizer: Tokenizer,
    // Id and decoded text of an arbitrary placeholder token ("A"), used to
    // absorb the decoders' special-casing of the first token in a sequence.
    prefix_id: u32,
    prefix: String,
}

impl Decoder {
    pub(crate) fn new(tokenizer: Tokenizer) -> Decoder {
        let prefix_id = tokenizer.token_to_id("A").unwrap();
        Decoder {
            prefix_id,
            prefix: tokenizer.decode(vec![prefix_id], false).unwrap(),
            tokenizer,
        }
    }

    /// Decode continuation tokens to be added to some existing text
    pub(crate) fn decode_continuation(&self, mut ids: Vec<u32>) -> tokenizers::Result<String> {
        // How we handle this depends on the specific decoder's behaviour;
        // see each one's implementation of decode_chain in the tokenizers library.
        match self.tokenizer.get_decoder() {
            Some(ByteLevel(_)) => {
                // Lossless - call the standard decode function
                self.tokenizer.decode(ids, true)
            }
            Some(Metaspace(_)) | Some(WordPiece(_)) | Some(BPE(_)) => {
                // For these, the first token in the sequence is treated differently,
                // so we prepend and then strip the placeholder token.
                ids.insert(0, self.prefix_id);
                let result = self.tokenizer.decode(ids, true)?;
                Ok(result
                    .strip_prefix(&self.prefix)
                    .ok_or(DecodingError)?
                    .to_string())
            }
            None => {
                // No decoder - just prepend a space
                Ok(format!(" {}", self.tokenizer.decode(ids, true)?))
            }
            _ => Err(UnsupportedTokenizerError.into()),
        }
    }
}
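Usage then looks something like this (the file path and ids are illustrative):

fn main() -> tokenizers::Result<()> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let decoder = Decoder::new(tokenizer);

    // Decode the first chunk normally, then append continuations as
    // further ids arrive from generation.
    let mut text = decoder.tokenizer.decode(vec![101, 102], true)?;
    text += &decoder.decode_continuation(vec![103, 104])?;
    text += &decoder.decode_continuation(vec![105])?;
    println!("{text}");
    Ok(())
}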