server : vision support via libmtmd (#12898)
* server : (experimental) vision support via libmtmd * mtmd : add more api around mtmd_image_tokens * mtmd : add more api around mtmd_image_tokens * mtmd : ability to calc image hash * shared_ptr for mtmd_image_tokens * move hash to user-define ID (fixed) * abstract out the batch management * small fix * refactor logic adding tokens to batch * implement hashing image * use FNV hash, now hash bitmap instead of file data * allow decoding image embedding to be split into batches * rm whitespace * disable some features when mtmd is on * fix --no-mmproj-offload * mtmd_context_params no timings * refactor server_inp to server_tokens * fix the failing test case * init * wip * working version * add mtmd::bitmaps * add test target * rm redundant define * test: mtmd_input_chunks_free * rm outdated comment * fix merging issue * explicitly create mtmd::input_chunks * mtmd_input_chunk_copy * add clone() * improve server_input struct * clip : fix confused naming ffn_up and ffn_down * rm ffn_i/o/g naming * rename n_embd, n_ff * small fix * no check n_ff * fix detokenize * add const to various places * add warning about breaking changes * add c api * helper: use mtmd_image_tokens_get_n_pos * fix ctx_shift * fix name shadowing * more strict condition * support remote image_url * remote image_url log * add CI test * do not log base64 * add "has_multimodal" to /props * remove dangling image * speculative: use slot.cache_tokens.insert * Apply suggestions from code review Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * rm can_be_detokenized * on prmpt processing done, assert cache_tokens.size * handle_completions_impl returns void * adapt the new web ui * update docs and hot topics * rm assert * small fix (2) --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
17512a94d6
commit
33eff40240
10 changed files with 774 additions and 101 deletions
59
tools/server/tests/unit/test_vision_api.py
Normal file
59
tools/server/tests/unit/test_vision_api.py
Normal file
|
@ -0,0 +1,59 @@
|
|||
import pytest
|
||||
from utils import *
|
||||
import base64
|
||||
import requests
|
||||
|
||||
server: ServerProcess
|
||||
|
||||
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
|
||||
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
|
||||
|
||||
response = requests.get(IMG_URL_0)
|
||||
response.raise_for_status() # Raise an exception for bad status codes
|
||||
IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def create_server():
|
||||
global server
|
||||
server = ServerPreset.tinygemma3()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"prompt, image_url, success, re_content",
|
||||
[
|
||||
# test model is trained on CIFAR-10, but it's quite dumb due to small size
|
||||
("What is this:\n", IMG_URL_0, True, "(cat)+"),
|
||||
("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
|
||||
("What is this:\n", IMG_URL_1, True, "(frog)+"),
|
||||
("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
|
||||
("What is this:\n", "malformed", False, None),
|
||||
("What is this:\n", "https://google.com/404", False, None), # non-existent image
|
||||
("What is this:\n", "https://ggml.ai", False, None), # non-image data
|
||||
]
|
||||
)
|
||||
def test_vision_chat_completion(prompt, image_url, success, re_content):
|
||||
global server
|
||||
server.start(timeout_seconds=60) # vision model may take longer to load due to download size
|
||||
if image_url == "IMG_BASE64_0":
|
||||
image_url = IMG_BASE64_0
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
"temperature": 0.0,
|
||||
"top_k": 1,
|
||||
"messages": [
|
||||
{"role": "user", "content": [
|
||||
{"type": "text", "text": prompt},
|
||||
{"type": "image_url", "image_url": {
|
||||
"url": image_url,
|
||||
}},
|
||||
]},
|
||||
],
|
||||
})
|
||||
if success:
|
||||
assert res.status_code == 200
|
||||
choice = res.body["choices"][0]
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex(re_content, choice["message"]["content"])
|
||||
else:
|
||||
assert res.status_code != 200
|
||||
|
|
@ -88,6 +88,7 @@ class ServerProcess:
|
|||
chat_template: str | None = None
|
||||
chat_template_file: str | None = None
|
||||
server_path: str | None = None
|
||||
mmproj_url: str | None = None
|
||||
|
||||
# session variables
|
||||
process: subprocess.Popen | None = None
|
||||
|
@ -194,6 +195,8 @@ class ServerProcess:
|
|||
server_args.extend(["--chat-template", self.chat_template])
|
||||
if self.chat_template_file:
|
||||
server_args.extend(["--chat-template-file", self.chat_template_file])
|
||||
if self.mmproj_url:
|
||||
server_args.extend(["--mmproj-url", self.mmproj_url])
|
||||
|
||||
args = [str(arg) for arg in [server_path, *server_args]]
|
||||
print(f"tests: starting server with: {' '.join(args)}")
|
||||
|
@ -379,6 +382,21 @@ class ServerPreset:
|
|||
server.server_reranking = True
|
||||
return server
|
||||
|
||||
@staticmethod
|
||||
def tinygemma3() -> ServerProcess:
|
||||
server = ServerProcess()
|
||||
# mmproj is already provided by HF registry API
|
||||
server.model_hf_repo = "ggml-org/tinygemma3-GGUF"
|
||||
server.model_hf_file = "tinygemma3-Q8_0.gguf"
|
||||
server.mmproj_url = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/mmproj-tinygemma3.gguf"
|
||||
server.model_alias = "tinygemma3"
|
||||
server.n_ctx = 1024
|
||||
server.n_batch = 32
|
||||
server.n_slots = 2
|
||||
server.n_predict = 4
|
||||
server.seed = 42
|
||||
return server
|
||||
|
||||
|
||||
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
|
||||
"""
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue