server : replace behave with pytest (#10416)

* server : replace behave with pytest

* fix test on windows

* misc

* add more tests

* more tests

* styling

* log less, fix embd test

* added all sequential tests

* fix coding style

* fix save slot test

* add parallel completion test

* fix parallel test

* remove feature files

* update test docs

* no cache_prompt for some tests

* add test_cache_vs_nocache_prompt
This commit is contained in:
Xuan Son Nguyen 2024-11-26 16:20:18 +01:00 committed by GitHub
parent 0bbd2262a3
commit 45abe0f74e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 1317 additions and 2497 deletions

View file

@ -0,0 +1,34 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_server_start_simple():
global server
server.start()
res = server.make_request("GET", "/health")
assert res.status_code == 200
def test_server_props():
global server
server.start()
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert res.body["total_slots"] == server.n_slots
def test_server_models():
global server
server.start()
res = server.make_request("GET", "/models")
assert res.status_code == 200
assert len(res.body["data"]) == 1
assert res.body["data"][0]["id"] == server.model_alias

View file

@ -0,0 +1,129 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
[
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
})
assert res.status_code == 200
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"])
if truncated:
assert choice["finish_reason"] == "length"
else:
assert choice["finish_reason"] == "stop"
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
[
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
]
)
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
global server
server.start()
res = server.make_stream_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"stream": True,
})
content = ""
for data in res:
choice = data["choices"][0]
if choice["finish_reason"] in ["stop", "length"]:
assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted
assert "content" not in choice["delta"]
assert match_regex(re_content, content)
# FIXME: not sure why this is incorrect in stream mode
# if truncated:
# assert choice["finish_reason"] == "length"
# else:
# assert choice["finish_reason"] == "stop"
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"]
def test_chat_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=8,
seed=42,
temperature=0.8,
)
print(res)
assert res.choices[0].finish_reason == "stop"
assert res.choices[0].message.content is not None
assert match_regex("(Suddenly)+", res.choices[0].message.content)
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
({"type": "json_object"}, 10, "(\\{|John)+"),
({"type": "sound"}, 0, None),
# invalid response format (expected to fail)
({"type": "json_object", "schema": 123}, 0, None),
({"type": "json_object", "schema": {"type": 123}}, 0, None),
({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"response_format": response_format,
})
if re_content is not None:
assert res.status_code == 200
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"])
else:
assert res.status_code != 200
assert "error" in res.body

View file

@ -0,0 +1,223 @@
import pytest
import time
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": prompt,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == n_prompt
assert res.body["timings"]["predicted_n"] == n_predicted
assert res.body["truncated"] == truncated
assert match_regex(re_content, res.body["content"])
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
global server
server.start()
res = server.make_stream_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": prompt,
"stream": True,
})
content = ""
for data in res:
if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated
assert match_regex(re_content, content)
else:
content += data["content"]
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for seed in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": seed,
"temperature": 1.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] != last_res.body["content"]
last_res = res
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0, 1.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
global server
server.n_batch = n_batch
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": temperature,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.skip(reason="This test fails on linux, need to be fixed")
def test_cache_vs_nocache_prompt():
global server
server.start()
res_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": True,
})
res_no_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": False,
})
assert res_cache.body["content"] == res_no_cache.body["content"]
def test_completion_with_tokens_input():
global server
server.temperature = 0.0
server.start()
prompt_str = "I believe the meaning of life is"
res = server.make_request("POST", "/tokenize", data={
"content": prompt_str,
"add_special": True,
})
assert res.status_code == 200
tokens = res.body["tokens"]
# single completion
res = server.make_request("POST", "/completion", data={
"prompt": tokens,
})
assert res.status_code == 200
assert type(res.body["content"]) == str
# batch completion
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, tokens],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, prompt_str],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens in one sequence
res = server.make_request("POST", "/completion", data={
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
})
assert res.status_code == 200
assert type(res.body["content"]) == str
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 3),
(2, 2),
(2, 4),
(4, 2), # some slots must be idle
(4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.temperature = 0.0
server.start()
PROMPTS = [
("Write a very long book.", "(very|special|big)+"),
("Write another a poem.", "(small|house)+"),
("What is LLM?", "(Dad|said)+"),
("The sky is blue and I love it.", "(climb|leaf)+"),
("Write another very long music lyrics.", "(friends|step|sky)+"),
("Write a very long joke.", "(cat|Whiskers)+"),
]
def check_slots_status():
should_all_slots_busy = n_requests >= n_slots
time.sleep(0.1)
res = server.make_request("GET", "/slots")
n_busy = sum([1 for slot in res.body if slot["is_processing"]])
if should_all_slots_busy:
assert n_busy == n_slots
else:
assert n_busy <= n_slots
tasks = []
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
tasks.append((server.make_request, ("POST", "/completion", {
"prompt": prompt,
"seed": 42,
"temperature": 1.0,
})))
tasks.append((check_slots_status, ()))
results = parallel_function_calls(tasks)
# check results
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
res = results[i]
assert res.status_code == 200
assert type(res.body["content"]) == str
assert len(res.body["content"]) > 10
# FIXME: the result is not deterministic when using other slot than slot 0
# assert match_regex(re_content, res.body["content"])

View file

@ -0,0 +1,67 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.n_ctx = 256
server.n_slots = 2
def test_ctx_shift_enabled():
# the prompt is 301 tokens
# the slot context is 256/2 = 128 tokens
# the prompt is truncated to keep the last 109 tokens
# 64 tokens are generated thanks to shifting the context when it gets full
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == 109
assert res.body["timings"]["predicted_n"] == 64
assert res.body["truncated"] is True
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
(64, 64, False),
(-1, 120, True),
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
global server
server.disable_ctx_shift = True
server.n_predict = -1
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": "Hi how are you",
})
assert res.status_code == 200
assert res.body["timings"]["predicted_n"] == n_token_output
assert res.body["truncated"] == truncated
def test_ctx_shift_disabled_long_prompt():
global server
server.disable_ctx_shift = True
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
})
assert res.status_code != 200
assert "error" in res.body
assert "exceeds the available context size" in res.body["error"]["message"]

View file

@ -0,0 +1,99 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.bert_bge_small()
EPSILON = 1e-3
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.bert_bge_small()
def test_embedding_single():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1
# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
def test_embedding_multiple():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 4
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
def test_embedding_openai_library_single():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1
def test_embedding_openai_library_multiple():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
])
assert len(res.data) == 4
for d in res.data:
assert len(d.embedding) > 1
def test_embedding_error_prompt_too_long():
global server
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200
assert "too large" in res.body["error"]["message"]
def test_same_prompt_give_same_result():
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 5
for i in range(1, len(res.body['data'])):
v0 = res.body['data'][0]['embedding']
vi = res.body['data'][i]['embedding']
for x, y in zip(v0, vi):
assert abs(x - y) < EPSILON

View file

@ -0,0 +1,35 @@
import pytest
from utils import *
server = ServerPreset.tinyllama_infill()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama_infill()
def test_infill_without_input_extra():
global server
server.start()
res = server.make_request("POST", "/infill", data={
"prompt": "Complete this",
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
def test_infill_with_input_extra():
global server
server.start()
res = server.make_request("POST", "/infill", data={
"prompt": "Complete this",
"input_extra": [{
"filename": "llama.h",
"text": "LLAMA_API int32_t llama_n_threads();\n"
}],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])

View file

@ -0,0 +1,42 @@
import pytest
import os
from utils import *
server = ServerPreset.stories15m_moe()
LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.stories15m_moe()
# download lora file if needed
file_name = LORA_FILE_URL.split('/').pop()
lora_file = f'../../../{file_name}'
if not os.path.exists(lora_file):
print(f"Downloading {LORA_FILE_URL} to {lora_file}")
with open(lora_file, 'wb') as f:
f.write(requests.get(LORA_FILE_URL).content)
print(f"Done downloading lora file")
server.lora_files = [lora_file]
@pytest.mark.parametrize("scale,re_content", [
# without applying lora, the model should behave like a bedtime story generator
(0.0, "(little|girl|three|years|old)+"),
# with lora, the model should behave like a Shakespearean text generator
(1.0, "(eye|love|glass|sun)+"),
])
def test_lora(scale: float, re_content: str):
global server
server.start()
res_lora_control = server.make_request("POST", "/lora-adapters", data=[
{"id": 0, "scale": scale}
])
assert res_lora_control.status_code == 200
res = server.make_request("POST", "/completion", data={
"prompt": "Look in thy glass",
})
assert res.status_code == 200
assert match_regex(re_content, res.body["content"])

View file

@ -0,0 +1,38 @@
import pytest
from utils import *
server = ServerPreset.jina_reranker_tiny()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.jina_reranker_tiny()
def test_rerank():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
})
assert res.status_code == 200
assert len(res.body["results"]) == 4
most_relevant = res.body["results"][0]
least_relevant = res.body["results"][0]
for doc in res.body["results"]:
if doc["relevance_score"] > most_relevant["relevance_score"]:
most_relevant = doc
if doc["relevance_score"] < least_relevant["relevance_score"]:
least_relevant = doc
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3

View file

@ -0,0 +1,83 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
TEST_API_KEY = "sk-this-is-the-secret-key"
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.api_key = TEST_API_KEY
@pytest.mark.parametrize("endpoint", ["/health", "/models"])
def test_access_public_endpoint(endpoint: str):
global server
server.start()
res = server.make_request("GET", endpoint)
assert res.status_code == 200
assert "error" not in res.body
@pytest.mark.parametrize("api_key", [None, "invalid-key"])
def test_incorrect_api_key(api_key: str):
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {api_key}" if api_key else None,
})
assert res.status_code == 401
assert "error" in res.body
assert res.body["error"]["type"] == "authentication_error"
def test_correct_api_key():
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {TEST_API_KEY}",
})
assert res.status_code == 200
assert "error" not in res.body
assert "content" in res.body
def test_openai_library_correct_api_key():
global server
server.start()
client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a chatbot."},
{"role": "user", "content": "What is the meaning of life?"},
],
)
assert len(res.choices) == 1
@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
("localhost", "Access-Control-Allow-Origin", "localhost"),
("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
("origin", "Access-Control-Allow-Credentials", "true"),
("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
])
def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
global server
server.start()
res = server.make_request("OPTIONS", "/completions", headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Authorization",
})
assert res.status_code == 200
assert cors_header in res.headers
assert res.headers[cors_header] == cors_header_value

View file

@ -0,0 +1,98 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.slot_save_path = "./tmp"
server.temperature = 0.0
def test_slot_save_restore():
global server
server.start()
# First prompt in slot 1 should be fully processed
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
# Save state of slot 1
res = server.make_request("POST", "/slots/1?action=save", data={
"filename": "slot1.bin",
})
assert res.status_code == 200
assert res.body["n_saved"] == 84
# Since we have cache, this should only process the last tokens
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
# Loading the saved cache into slot 0
res = server.make_request("POST", "/slots/0?action=restore", data={
"filename": "slot1.bin",
})
assert res.status_code == 200
assert res.body["n_restored"] == 84
# Since we have cache, slot 0 should only process the last tokens
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 0,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
# For verification that slot 1 was not corrupted during slot 0 load, same thing should work
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 1
def test_slot_erase():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
# erase slot 1
res = server.make_request("POST", "/slots/1?action=erase")
assert res.status_code == 200
# re-run the same prompt, it should process all tokens again
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed

View file

@ -0,0 +1,59 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_tokenize_detokenize():
global server
server.start()
# tokenize
content = "What is the capital of France ?"
res_tok = server.make_request("POST", "/tokenize", data={
"content": content
})
assert res_tok.status_code == 200
assert len(res_tok.body["tokens"]) > 5
# detokenize
res_detok = server.make_request("POST", "/detokenize", data={
"tokens": res_tok.body["tokens"],
})
assert res_detok.status_code == 200
assert res_detok.body["content"].strip() == content
def test_tokenize_with_bos():
global server
server.start()
# tokenize
content = "What is the capital of France ?"
bosId = 1
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"add_special": True,
})
assert res_tok.status_code == 200
assert res_tok.body["tokens"][0] == bosId
def test_tokenize_with_pieces():
global server
server.start()
# tokenize
content = "This is a test string with unicode 媽 and emoji 🤗"
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"with_pieces": True,
})
assert res_tok.status_code == 200
for token in res_tok.body["tokens"]:
assert "id" in token
assert token["id"] > 0
assert "piece" in token
assert len(token["piece"]) > 0