
Commit cfc0dff

Lin authored and gitee-org committed
!1087 Out-of-bounds error message for tokenizer decode index
Merge pull request !1087 from zyw_hw/tokenizer_index
2 parents 81c48ca + 781fd8f commit cfc0dff


3 files changed: +15 -30 lines


mindformers/models/base_tokenizer.py

Lines changed: 7 additions & 0 deletions
@@ -4238,6 +4238,9 @@ def convert_ids_to_tokens(
         if isinstance(ids, int):
             if ids in self.added_tokens_decoder:
                 return self.added_tokens_decoder[ids]
+            if ids >= self.vocab_size:
+                raise IndexError(f"The token id {ids} is out of the size of vocabulary, please check your tokenizer "
+                                 f"and corresponding vocabulary files.")
             return self._convert_id_to_token(ids)
         tokens = []
         for index in ids:
@@ -4247,6 +4250,10 @@ def convert_ids_to_tokens(
             if index in self.added_tokens_decoder:
                 tokens.append(self.added_tokens_decoder[index])
             else:
+                if index >= self.vocab_size:
+                    raise IndexError(
+                        f"The token id {index} is out of the size of vocabulary, please check your tokenizer "
+                        f"and corresponding vocabulary files.")
                 tokens.append(self._convert_id_to_token(index))
         return tokens
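
With this check in the shared base tokenizer, decoding an id that falls outside the vocabulary fails fast with a readable message. A minimal usage sketch, assuming MindFormers' AutoTokenizer entry point and the "glm2_6b" model alias (both are illustrative, not part of this diff):

# Minimal sketch: any tokenizer inheriting the patched base class should now raise
# IndexError for out-of-range ids. AutoTokenizer and "glm2_6b" are assumed names.
from mindformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("glm2_6b")  # placeholder model alias
bad_id = tokenizer.vocab_size + 10                    # any id outside the vocabulary

try:
    tokenizer.convert_ids_to_tokens(bad_id)
except IndexError as err:
    print(err)  # "The token id ... is out of the size of vocabulary, please check your tokenizer ..."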

mindformers/models/glm/chatglm_6b_tokenizer.py

Lines changed: 4 additions & 30 deletions
@@ -302,6 +302,10 @@ def _decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_sp
             token_ids = [token_ids]
         if self.pad_token_id in token_ids:  # remove pad
             token_ids = list(filter(self.pad_token_id.__ne__, token_ids))
+        for token_id in token_ids:
+            if token_id not in self.added_tokens_decoder and token_id >= self.vocab_size:
+                raise IndexError(f"The token id {token_id} is out of the size of vocabulary, please check "
+                                 f"your tokenizer and corresponding vocabulary files.")
         return self.sp_tokenizer.decode(token_ids)

     # pylint:disable=arguments-differ
@@ -358,36 +362,6 @@ def _convert_token_to_id(self, token):
             return self.added_tokens_encoder[token]
         return self.sp_tokenizer[token]

-    # pylint:disable=arguments-differ
-    def convert_ids_to_tokens(self, ids: Union[int, List[int]], skip_special_tokens: bool = False):
-        """
-        Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and
-        added tokens.
-
-        Args:
-            ids (`int` or `List[int]`):
-                The token id (or token ids) to convert to tokens.
-            skip_special_tokens (`bool`, *optional*, defaults to `False`):
-                Whether or not to remove special tokens in the decoding.
-
-        Returns:
-            `str` or `List[str]`: The decoded token(s).
-        """
-        if isinstance(ids, int):
-            if ids in self.added_tokens_decoder:
-                return self.added_tokens_decoder[ids]
-            return self._convert_id_to_token(ids)
-        tokens = []
-        for index in ids:
-            index = int(index)
-            if skip_special_tokens and index in self.all_special_ids:
-                continue
-            if index in self.added_tokens_decoder:
-                tokens.append(self.added_tokens_decoder[index])
-            else:
-                tokens.append(self._convert_id_to_token(index))
-        return tokens
-
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
         return self.sp_tokenizer[index]
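
The guard added to _decode above can also be read in isolation. A standalone sketch of the same pattern (the helper name is hypothetical, not part of the MindFormers API):

# Hypothetical helper mirroring the guard this commit adds to _decode():
# reject any id that is neither an added token nor inside the vocabulary.
from typing import Dict, List


def check_ids_in_vocab(ids: List[int], vocab_size: int, added_tokens_decoder: Dict[int, str]) -> None:
    for token_id in ids:
        if token_id not in added_tokens_decoder and token_id >= vocab_size:
            raise IndexError(f"The token id {token_id} is out of the size of vocabulary, please check "
                             f"your tokenizer and corresponding vocabulary files.")


check_ids_in_vocab([3, 7], vocab_size=10, added_tokens_decoder={})  # passes silently
# check_ids_in_vocab([12], vocab_size=10, added_tokens_decoder={})  # would raise IndexError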

mindformers/models/glm2/glm2_tokenizer.py

Lines changed: 4 additions & 0 deletions
@@ -218,6 +218,10 @@ def _decode(self,

     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
+        if index >= self.vocab_size:
+            raise IndexError(
+                f"The token id {index} is out of the size of vocabulary, please check your tokenizer "
+                f"and corresponding vocabulary files.")
         return self.tokenizer.convert_id_to_token(index)

     def convert_tokens_to_string(self, tokens: List[str]) -> str:
