llama : improve sep token handling (#14272)
This commit is contained in:
parent
e28c1b93fd
commit
88fc854b4b
15 changed files with 161 additions and 29 deletions
|
@ -198,6 +198,7 @@ class Keys:
|
|||
MASK_ID = "tokenizer.ggml.mask_token_id"
|
||||
ADD_BOS = "tokenizer.ggml.add_bos_token"
|
||||
ADD_EOS = "tokenizer.ggml.add_eos_token"
|
||||
ADD_SEP = "tokenizer.ggml.add_sep_token"
|
||||
ADD_PREFIX = "tokenizer.ggml.add_space_prefix"
|
||||
REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces"
|
||||
PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap"
|
||||
|
|
|
@ -891,6 +891,9 @@ class GGUFWriter:
|
|||
def add_add_eos_token(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.ADD_EOS, value)
|
||||
|
||||
def add_add_sep_token(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.ADD_SEP, value)
|
||||
|
||||
def add_add_space_prefix(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Tokenizer.ADD_PREFIX, value)
|
||||
|
||||
|
|
|
@ -119,6 +119,7 @@ class SpecialVocab:
|
|||
logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')
|
||||
|
||||
def _try_load_from_tokenizer_json(self, path: Path) -> bool:
|
||||
tokenizer = None
|
||||
tokenizer_file = path / 'tokenizer.json'
|
||||
if tokenizer_file.is_file():
|
||||
with open(tokenizer_file, encoding = 'utf-8') as f:
|
||||
|
@ -152,11 +153,87 @@ class SpecialVocab:
|
|||
added_tokens = tokenizer.get('added_tokens', {})
|
||||
else:
|
||||
added_tokens = {}
|
||||
tokenizer_config = None
|
||||
tokenizer_config_file = path / 'tokenizer_config.json'
|
||||
if not tokenizer_config_file.is_file():
|
||||
if tokenizer_config_file.is_file():
|
||||
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
||||
tokenizer_config = json.load(f)
|
||||
if tokenizer:
|
||||
special_bos = (tokenizer_config or {}).get('bos_token')
|
||||
special_cls = (tokenizer_config or {}).get('cls_token')
|
||||
special_eos = (tokenizer_config or {}).get('eos_token')
|
||||
special_sep = (tokenizer_config or {}).get('sep_token')
|
||||
if not special_bos and special_cls and tokenizer_config:
|
||||
tokenizer_config['bos_token'] = special_bos = special_cls
|
||||
if not special_eos and special_sep and tokenizer_config:
|
||||
tokenizer_config['eos_token'] = special_eos = special_sep
|
||||
post_processor = tokenizer.get('post_processor', {})
|
||||
for processor in post_processor.get('processors', [post_processor]):
|
||||
if processor.get('type') == 'RobertaProcessing':
|
||||
self.add_special_token['bos'] = True
|
||||
self.add_special_token['eos'] = True
|
||||
self.add_special_token['sep'] = True
|
||||
if not special_cls and tokenizer_config:
|
||||
special_cls = processor.get('cls', [special_bos])[0]
|
||||
tokenizer_config['cls_token'] = special_cls
|
||||
if not special_sep and tokenizer_config:
|
||||
special_sep = processor.get('sep', [special_eos])[0]
|
||||
tokenizer_config['sep_token'] = special_sep
|
||||
continue
|
||||
# Crude parsing of TemplateProcessing to determine if BOS/SEP/EOS should be added
|
||||
# Only works with simple templates, **will** get it wrong on unusual sequences
|
||||
if processor.get('type') == 'TemplateProcessing':
|
||||
tmpl_single = processor.get('single', [])
|
||||
tmpl_pair = processor.get('pair', [])
|
||||
special_first = None
|
||||
special_last = None
|
||||
if len(tmpl_single) > 1:
|
||||
if special_first := tmpl_single[0].get('SpecialToken', {}).get('id'):
|
||||
if not tokenizer_config:
|
||||
special_bos = special_first
|
||||
self.add_special_token['bos'] = True if special_first in (special_bos, special_cls) else False
|
||||
if special_first not in (special_bos, special_cls):
|
||||
logger.warning(f'Unknown leading special token {special_first!r} in TemplateProcessing<single>')
|
||||
if special_last := tmpl_single[-1].get('SpecialToken', {}).get('id'):
|
||||
if not tokenizer_config:
|
||||
special_eos = special_last
|
||||
self.add_special_token['eos'] = True if special_last == special_eos else False
|
||||
if special_last != special_eos:
|
||||
logger.warning(f'Unknown trailing special token {special_last!r} in TemplateProcessing<single>')
|
||||
if tmpl_pair:
|
||||
seq_start = 1 if tmpl_pair[0].get('SpecialToken', {}).get('id') == special_first else 0
|
||||
seq_stop = -1 if tmpl_pair[-1].get('SpecialToken', {}).get('id') == special_last else None
|
||||
if seq_start == 0 or seq_stop is None:
|
||||
logger.warning('TemplateProcessing<single> leading/trailing special tokens do not match TemplateProcessing<pair>')
|
||||
if tmpl_pair := tmpl_pair[slice(seq_start, seq_stop)]:
|
||||
tmpl_a = tmpl_pair[0].get('Sequence', {}).get('id')
|
||||
tmpl_b = tmpl_pair[-1].get('Sequence', {}).get('id')
|
||||
if tmpl_a != 'A' or tmpl_b != 'B':
|
||||
logger.warning(f'Unknown sequence {tmpl_a}...{tmpl_b} in TemplateProcessing<pair>')
|
||||
# A [sep] [eos] B
|
||||
if tmpl_a == 'A' and tmpl_b == 'B' and (tmpl_pair := tmpl_pair[1:-1]):
|
||||
add_sep = False
|
||||
if special_entry := tmpl_pair[0].get('SpecialToken', {}).get('id'):
|
||||
if special_entry in (special_sep, special_eos) and not special_last:
|
||||
add_sep = True
|
||||
if special_entry not in (special_sep, special_eos):
|
||||
logger.warning(f'Unknown separator token {special_entry!r} in TemplateProcessing<pair>')
|
||||
else:
|
||||
logger.warning(f'Unknown middle sequence {tmpl_pair[0]!r} in TemplateProcessing<pair>')
|
||||
if len(tmpl_pair) == 2:
|
||||
if special_entry := tmpl_pair[1].get('SpecialToken', {}).get('id'):
|
||||
if special_entry in (special_sep, special_eos):
|
||||
add_sep = True
|
||||
if special_entry not in (special_sep, special_eos):
|
||||
logger.warning(f'Unknown second separator token {special_entry!r} in TemplateProcessing<pair>')
|
||||
else:
|
||||
logger.warning(f'Unknown second middle sequence {tmpl_pair[1]!r} in TemplateProcessing<pair>')
|
||||
self.add_special_token['sep'] = add_sep
|
||||
if add_sep and not special_sep and tokenizer_config:
|
||||
tokenizer_config['sep_token'] = special_eos
|
||||
continue
|
||||
if not tokenizer_config:
|
||||
return True
|
||||
with open(tokenizer_config_file, encoding = 'utf-8') as f:
|
||||
tokenizer_config = json.load(f)
|
||||
chat_template_alt = None
|
||||
chat_template_file = path / 'chat_template.json'
|
||||
if chat_template_file.is_file():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue