llama : add RobertaForSequenceClassification reranker support (#13875)

This commit is contained in:
Sigbjørn Skjæret 2025-05-29 08:15:01 +02:00 committed by GitHub
parent 1b8fb8152d
commit 6385b843a8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 24 additions and 8 deletions

View file

@ -3695,6 +3695,10 @@ class BertModel(TextModel):
self.gguf_writer.add_causal_attention(False)
self._try_set_pooling_type()
if cls_out_labels := self.hparams.get("id2label"):
key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
self.gguf_writer.add_array(key_name, [v for k, v in sorted(cls_out_labels.items())])
def set_vocab(self):
tokens, toktypes, tokpre = self.get_vocab_base()
self.vocab_size = len(tokens)
@ -3745,12 +3749,13 @@ class BertModel(TextModel):
if name.startswith("cls.seq_relationship"):
return []
# For BertForSequenceClassification (direct projection layer)
if name == "classifier.weight":
name = "classifier.out_proj.weight"
if self.hparams.get("id2label"):
# For BertForSequenceClassification (direct projection layer)
if name == "classifier.weight":
name = "classifier.out_proj.weight"
if name == "classifier.bias":
name = "classifier.out_proj.bias"
if name == "classifier.bias":
name = "classifier.out_proj.bias"
return [(self.map_tensor_name(name), data_torch)]
@ -3846,7 +3851,7 @@ class BertModel(TextModel):
self.gguf_writer.add_add_eos_token(True)
@ModelBase.register("RobertaModel")
@ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
class RobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT