convert : fix vocab padding code for bert models (#13954)

This commit is contained in:
Sigbjørn Skjæret 2025-06-01 17:23:11 +02:00 committed by GitHub
parent e57bb87ced
commit c496fe0b1d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -3814,7 +3814,7 @@ class BertModel(TextModel):
remove_whitespaces = tokenizer.clean_up_tokenization_spaces
precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size)
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size)
else:
sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue]
sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
@ -3827,7 +3827,7 @@ class BertModel(TextModel):
tokenizer = SentencePieceProcessor()
tokenizer.LoadFromFile(str(tokenizer_path))
vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size())
tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
scores: list[float] = [-10000.0] * vocab_size
@ -3857,33 +3857,26 @@ class BertModel(TextModel):
unk_token = tokenizer_config_json.get("unk_token")
unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3))
for token_id in range(vocab_size):
for token_id in range(tokenizer.vocab_size):
piece = tokenizer._convert_id_to_token(token_id)
text = piece.encode("utf-8")
score = tokenizer_json["model"]["vocab"][token_id][1]
if (piece := tokenizer._convert_id_to_token(token_id)) is not None:
text = piece.encode("utf-8")
score = tokenizer_json["model"]["vocab"][token_id][1]
toktype = SentencePieceTokenTypes.NORMAL
if token_id == unk_token_id:
toktype = SentencePieceTokenTypes.UNKNOWN
elif token_id in tokenizer.all_special_ids:
toktype = SentencePieceTokenTypes.CONTROL
elif token_id in added_vocab.values():
toktype = SentencePieceTokenTypes.USER_DEFINED
# No reliable way to detect this, but jina doesn't have any
# elif tokenizer.IsByte(token_id):
# toktype = SentencePieceTokenTypes.BYTE
toktype = SentencePieceTokenTypes.NORMAL
if token_id == unk_token_id:
toktype = SentencePieceTokenTypes.UNKNOWN
elif token_id in tokenizer.all_special_ids:
toktype = SentencePieceTokenTypes.CONTROL
elif token_id in added_vocab.values():
toktype = SentencePieceTokenTypes.USER_DEFINED
# No reliable way to detect this, but jina doesn't have any
# elif tokenizer.IsByte(token_id):
# toktype = SentencePieceTokenTypes.BYTE
tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
if vocab_size > len(tokens):
pad_count = vocab_size - len(tokens)
logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
for i in range(1, pad_count + 1):
tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.UNUSED)
tokens[token_id] = text
scores[token_id] = score
toktypes[token_id] = toktype
if isinstance(tokenizer, SentencePieceProcessor):
# realign tokens (see HF tokenizer code)