llama : add RobertaForSequenceClassification reranker support (#13875)
This commit is contained in:
parent
1b8fb8152d
commit
6385b843a8
6 changed files with 24 additions and 8 deletions
|
@ -3695,6 +3695,10 @@ class BertModel(TextModel):
|
|||
self.gguf_writer.add_causal_attention(False)
|
||||
self._try_set_pooling_type()
|
||||
|
||||
if cls_out_labels := self.hparams.get("id2label"):
|
||||
key_name = gguf.Keys.Classifier.OUTPUT_LABELS.format(arch = gguf.MODEL_ARCH_NAMES[self.model_arch])
|
||||
self.gguf_writer.add_array(key_name, [v for k, v in sorted(cls_out_labels.items())])
|
||||
|
||||
def set_vocab(self):
|
||||
tokens, toktypes, tokpre = self.get_vocab_base()
|
||||
self.vocab_size = len(tokens)
|
||||
|
@ -3745,12 +3749,13 @@ class BertModel(TextModel):
|
|||
if name.startswith("cls.seq_relationship"):
|
||||
return []
|
||||
|
||||
# For BertForSequenceClassification (direct projection layer)
|
||||
if name == "classifier.weight":
|
||||
name = "classifier.out_proj.weight"
|
||||
if self.hparams.get("id2label"):
|
||||
# For BertForSequenceClassification (direct projection layer)
|
||||
if name == "classifier.weight":
|
||||
name = "classifier.out_proj.weight"
|
||||
|
||||
if name == "classifier.bias":
|
||||
name = "classifier.out_proj.bias"
|
||||
if name == "classifier.bias":
|
||||
name = "classifier.out_proj.bias"
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
@ -3846,7 +3851,7 @@ class BertModel(TextModel):
|
|||
self.gguf_writer.add_add_eos_token(True)
|
||||
|
||||
|
||||
@ModelBase.register("RobertaModel")
|
||||
@ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
|
||||
class RobertaModel(BertModel):
|
||||
model_arch = gguf.MODEL_ARCH.BERT
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue