llama : add support for BertForSequenceClassification reranker (#13858)
* convert: add support for BertForSequenceClassification
* add support for reranking using BertForSequenceClassification
* merge the EOS and SEP token checks
* fix lint

Co-authored-by: dinhhuy <huy.dinh@brains-tech.co.jp>
This commit is contained in:
parent
aa6dff05be
commit
e0e3aa231d
4 changed files with 42 additions and 21 deletions
|
@ -903,13 +903,16 @@ struct common_init_result common_init_from_params(common_params & params) {
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
|
bool has_eos = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
|
||||||
LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
|
bool has_sep = llama_vocab_sep(vocab) != LLAMA_TOKEN_NULL;
|
||||||
ok = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
|
if (!has_eos && !has_sep) {
|
||||||
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
LOG_WRN("%s: warning: vocab does not have an EOS token or SEP token, reranking will not work\n", __func__);
|
||||||
|
ok = false;
|
||||||
|
} else if (!has_eos) {
|
||||||
|
LOG_WRN("%s: warning: vocab does not have an EOS token, using SEP token as fallback\n", __func__);
|
||||||
|
} else if (!has_sep) {
|
||||||
|
LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
|
||||||
ok = false;
|
ok = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -3682,7 +3682,7 @@ class InternLM3Model(TextModel):
|
||||||
return [(self.map_tensor_name(name), data_torch)]
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
|
|
||||||
@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel")
|
@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification")
|
||||||
class BertModel(TextModel):
|
class BertModel(TextModel):
|
||||||
model_arch = gguf.MODEL_ARCH.BERT
|
model_arch = gguf.MODEL_ARCH.BERT
|
||||||
|
|
||||||
|
@ -3745,6 +3745,13 @@ class BertModel(TextModel):
|
||||||
if name.startswith("cls.seq_relationship"):
|
if name.startswith("cls.seq_relationship"):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# For BertForSequenceClassification (direct projection layer)
|
||||||
|
if name == "classifier.weight":
|
||||||
|
name = "classifier.out_proj.weight"
|
||||||
|
|
||||||
|
if name == "classifier.bias":
|
||||||
|
name = "classifier.out_proj.bias"
|
||||||
|
|
||||||
return [(self.map_tensor_name(name), data_torch)]
|
return [(self.map_tensor_name(name), data_torch)]
|
||||||
|
|
||||||
def _xlmroberta_tokenizer_init(self) -> None:
|
def _xlmroberta_tokenizer_init(self) -> None:
|
||||||
|
|
|
@ -1562,20 +1562,25 @@ void llm_graph_context::build_pooling(
|
||||||
ggml_tensor * inp_cls = build_inp_cls();
|
ggml_tensor * inp_cls = build_inp_cls();
|
||||||
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
inp = ggml_get_rows(ctx0, inp, inp_cls);
|
||||||
|
|
||||||
// classification head
|
if (cls != nullptr && cls_b != nullptr) {
|
||||||
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
// classification head
|
||||||
GGML_ASSERT(cls != nullptr);
|
// https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
|
||||||
GGML_ASSERT(cls_b != nullptr);
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
||||||
|
cur = ggml_tanh(ctx0, cur);
|
||||||
|
|
||||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
|
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
||||||
cur = ggml_tanh(ctx0, cur);
|
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
||||||
|
if (cls_out) {
|
||||||
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
|
GGML_ASSERT(cls_out_b != nullptr);
|
||||||
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
||||||
if (cls_out) {
|
}
|
||||||
|
} else if (cls_out) {
|
||||||
|
// Single layer classification head (direct projection)
|
||||||
|
// https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
|
||||||
GGML_ASSERT(cls_out_b != nullptr);
|
GGML_ASSERT(cls_out_b != nullptr);
|
||||||
|
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls_out, inp), cls_out_b);
|
||||||
cur = ggml_add (ctx0, ggml_mul_mat(ctx0, cls_out, cur), cls_out_b);
|
} else {
|
||||||
|
GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -264,13 +264,19 @@ static size_t validate_utf8(const std::string& text) {
|
||||||
// Build the token sequence fed to a reranking model:
//
//   [BOS] query [EOS] [SEP] doc [EOS]
//
// Some vocabularies (e.g. BERT-style sequence-classification rerankers) have
// no EOS token; in that case the SEP token is used in the EOS positions.
// The caller is expected to have verified that at least one of EOS/SEP exists.
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
    const llama_token sep = llama_vocab_sep(vocab);

    // prefer EOS; fall back to SEP when the vocab has no EOS token
    llama_token eos = llama_vocab_eos(vocab);
    if (eos == LLAMA_TOKEN_NULL) {
        eos = sep;
    }

    llama_tokens out;
    out.reserve(query.size() + doc.size() + 4); // 4 = BOS + EOS + SEP + EOS

    out.push_back(llama_vocab_bos(vocab));
    out.insert(out.end(), query.begin(), query.end());
    out.push_back(eos);
    out.push_back(sep);
    out.insert(out.end(), doc.begin(), doc.end());
    out.push_back(eos);

    return out;
}
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue