
Support for incremental decoding #1141

@njhill

Description

I would like to be able to decode a sequence of token ids incrementally in a decoder-agnostic manner. I haven't found a straightforward way to do this with the current API: the first token is treated differently by some decoders, which means that in general

decode([1,2,3]) != decode([1]) + decode([2]) + decode([3])

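To see why, here is a toy sketch in the style of the Metaspace decoder (an illustration only, not the actual tokenizers implementation): the "▁" word-boundary marker becomes a space, but the leading space is stripped from the first token of each decode call, so decoding token-by-token loses the separators.

```rust
// Toy Metaspace-style decode (hypothetical, for illustration only):
// "▁" marks a word boundary and becomes a space, but the leading space
// of the first token in the sequence is stripped.
fn decode(tokens: &[&str]) -> String {
    let joined: String = tokens.iter().map(|t| t.replace('▁', " ")).collect();
    joined.strip_prefix(' ').unwrap_or(&joined).to_string()
}

fn main() {
    let tokens = ["▁Hello", "▁world", "!"];
    let whole = decode(&tokens);
    // Decoding one token at a time strips each token's leading space,
    // so the word separators are lost:
    let piecewise: String = tokens.iter().map(|t| decode(&[*t])).collect();
    assert_eq!(whole, "Hello world!");
    assert_eq!(piecewise, "Helloworld!");
    assert_ne!(whole, piecewise);
}
```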
It would be really nice to have some kind of "continuation" flag to indicate that the result is intended to be appended to an already-decoded prefix, so that you could have

decode([1,2,3]) == decode([1]) + decode'([2]) + decode'([3])

It would also be nice to have a variant of this that takes a single u32 id or string token rather than a vec, for related reasons (the latter could be used with id_to_token).

I'd love to know if there is another way to achieve this than my current ugly workaround :)

Current workaround
// Assumes the DecoderWrapper variants are in scope, e.g.:
// use tokenizers::{DecoderWrapper::*, Tokenizer};
// DecodingError and UnsupportedTokenizerError are application-defined error types.
pub(crate) struct Decoder {
    pub(crate) tokenizer: Tokenizer,
    prefix_id: u32,
    prefix: String,
}

impl Decoder {
    pub(crate) fn new(tokenizer: Tokenizer) -> Decoder {
        let prefix_id = tokenizer.token_to_id("A").unwrap();
        Decoder {
            prefix_id,
            prefix: tokenizer.decode(vec![prefix_id], false).unwrap(),
            tokenizer,
        }
    }

    /// Decode continuation tokens to be added to some existing text
    pub(crate) fn decode_continuation(&self, mut ids: Vec<u32>) -> tokenizers::Result<String> {
        // How we handle this depends on the specific decoder's behaviour,
        // see each one's implementation of decode_chain in the tokenizers library.
        match self.tokenizer.get_decoder() {
            Some(ByteLevel(_)) => {
                // Lossless - call standard decode function
                self.tokenizer.decode(ids, true)
            },
            Some(Metaspace(_)) | Some(WordPiece(_)) | Some(BPE(_)) => {
                // For these, the first token in the sequence is treated differently,
                // so we add and then strip a placeholder token.
                ids.insert(0, self.prefix_id);
                let result = self.tokenizer.decode(ids, true)?;
                Ok(result.strip_prefix(&self.prefix).ok_or(DecodingError)?.to_string())
            },
            None => {
                // Just prepend a space
                Ok(format!(" {}", self.tokenizer.decode(ids, true)?))
            },
            _ => Err(UnsupportedTokenizerError.into())
        }
    }
}
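The placeholder-token trick above can be sketched against the same toy Metaspace-style decoder (hypothetical, not the real tokenizers API): prepend a token whose decoded text is known, decode, then strip that text, which leaves the continuation with its leading space intact.

```rust
// Toy Metaspace-style decode (hypothetical): "▁" marks a word boundary and
// becomes a space, except the leading space of the first token is stripped.
fn decode(tokens: &[&str]) -> String {
    let joined: String = tokens.iter().map(|t| t.replace('▁', " ")).collect();
    joined.strip_prefix(' ').unwrap_or(&joined).to_string()
}

// The workaround: prepend a known placeholder token, decode, then strip the
// placeholder's decoded text, leaving the continuation text unmangled.
fn decode_continuation(tokens: &[&str]) -> String {
    let mut with_prefix = vec!["▁A"]; // placeholder token, decodes to "A"
    with_prefix.extend_from_slice(tokens);
    let decoded = decode(&with_prefix);
    decoded.strip_prefix('A').expect("placeholder missing").to_string()
}

fn main() {
    let tokens = ["▁Hello", "▁world", "!"];
    // Incremental decoding now matches decoding the whole sequence at once:
    let mut streamed = decode(&tokens[..1]);
    for t in &tokens[1..] {
        streamed.push_str(&decode_continuation(&[*t]));
    }
    assert_eq!(streamed, decode(&tokens)); // both are "Hello world!"
}
```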
