llama : move end-user examples to tools directory (#13249)

* llama : move end-user examples to tools directory

---------

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Diego Devesa, 2025-05-02 20:27:13 +02:00
commit 1d36b3670b (parent b34443923c)

213 changed files with 226 additions and 190 deletions

@@ -0,0 +1,50 @@
set(TARGET llama-server)
option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
if (MINGW)
# fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006
add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
endif()
set(TARGET_SRCS
server.cpp
utils.hpp
httplib.h
)
set(PUBLIC_ASSETS
index.html.gz
loading.html
)
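# Each public asset is converted into a C++ header (via scripts/xxd.cmake below) and compiled
# into the binary, so the server does not need these files on disk at runtime.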
foreach(asset ${PUBLIC_ASSETS})
set(input "${CMAKE_CURRENT_SOURCE_DIR}/public/${asset}")
set(output "${CMAKE_CURRENT_BINARY_DIR}/${asset}.hpp")
list(APPEND TARGET_SRCS ${output})
add_custom_command(
DEPENDS "${input}"
OUTPUT "${output}"
COMMAND "${CMAKE_COMMAND}" "-DINPUT=${input}" "-DOUTPUT=${output}" -P "${PROJECT_SOURCE_DIR}/scripts/xxd.cmake"
)
set_source_files_properties(${output} PROPERTIES GENERATED TRUE)
endforeach()
add_executable(${TARGET} ${TARGET_SRCS})
install(TARGETS ${TARGET} RUNTIME)
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
target_link_libraries(${TARGET} PRIVATE common ${CMAKE_THREAD_LIBS_INIT})
if (LLAMA_SERVER_SSL)
find_package(OpenSSL REQUIRED)
target_link_libraries(${TARGET} PRIVATE OpenSSL::SSL OpenSSL::Crypto)
target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_OPENSSL_SUPPORT)
endif()
if (WIN32)
target_link_libraries(${TARGET} PRIVATE ws2_32)
endif()
target_compile_features(${TARGET} PRIVATE cxx_std_17)

tools/server/README.md (1267 lines, new file)
File diff suppressed because it is too large.

@@ -0,0 +1,119 @@
### Server benchmark tools
The benchmark uses [k6](https://k6.io/).
#### Install k6 and the SSE extension
SSE is not supported by default in k6; you have to build k6 with the [xk6-sse](https://github.com/phymbert/xk6-sse) extension.
Example (assuming Go >= 1.21 is installed):
```shell
go install go.k6.io/xk6/cmd/xk6@latest
$GOPATH/bin/xk6 build master \
--with github.com/phymbert/xk6-sse
```
#### Download a dataset
This dataset was originally proposed in [vLLM benchmarks](https://github.com/vllm-project/vllm/blob/main/benchmarks/README.md).
```shell
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
#### Download a model
Example for Phi-2:
```shell
../../../scripts/hf.sh --repo ggml-org/models --file phi-2/ggml-model-q4_0.gguf
```
#### Start the server
The server must answer OAI chat completion requests on `http://localhost:8080/v1`, or on the URL set by the `SERVER_BENCH_URL` environment variable.
Example:
```shell
llama-server --host localhost --port 8080 \
--model ggml-model-q4_0.gguf \
--cont-batching \
--metrics \
--parallel 8 \
--batch-size 512 \
--ctx-size 4096 \
-ngl 33
```
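Optionally, check that the server has finished loading the model before launching the benchmark. A minimal sketch, assuming the default host and port from the example above, queries the `/health` endpoint (the same endpoint `bench.py` polls for readiness):
```shell
# returns HTTP 200 once the server is ready to accept requests
curl -fsS http://localhost:8080/health
```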
#### Run the benchmark
For 500 chat completion requests with 8 concurrent users over a maximum of 10 minutes, run:
```shell
./k6 run script.js --duration 10m --iterations 500 --vus 8
```
The benchmark values can be overridden with:
- `SERVER_BENCH_URL` server url prefix for chat completions, default `http://localhost:8080/v1`
- `SERVER_BENCH_N_PROMPTS` total prompts to randomly select in the benchmark, default `480`
- `SERVER_BENCH_MODEL_ALIAS` model alias to pass in the completion request, default `my-model`
- `SERVER_BENCH_MAX_TOKENS` max tokens to predict, default: `512`
- `SERVER_BENCH_DATASET` path to the benchmark dataset file
- `SERVER_BENCH_MAX_PROMPT_TOKENS` maximum prompt tokens to filter out in the dataset, default `1024`
- `SERVER_BENCH_MAX_CONTEXT` maximum context size of the completion request (prompt + predicted tokens) to filter out in the dataset, default `2048`
Note: the local tokenizer is just a whitespace split, so the real number of tokens will differ.
Or with [k6 options](https://k6.io/docs/using-k6/k6-options/reference/):
```shell
SERVER_BENCH_N_PROMPTS=500 k6 run script.js --duration 10m --iterations 500 --vus 8
```
To [debug HTTP requests](https://k6.io/docs/using-k6/http-debugging/), use `--http-debug="full"`.
#### Metrics
The following metrics are computed from the OAI chat completion response `usage`:
- `llamacpp_tokens_second` Trend of `usage.total_tokens / request duration`
- `llamacpp_prompt_tokens` Trend of `usage.prompt_tokens`
- `llamacpp_prompt_tokens_total_counter` Counter of `usage.prompt_tokens`
- `llamacpp_completion_tokens` Trend of `usage.completion_tokens`
- `llamacpp_completion_tokens_total_counter` Counter of `usage.completion_tokens`
- `llamacpp_completions_truncated_rate` Rate of completions truncated, i.e. if `finish_reason === 'length'`
- `llamacpp_completions_stop_rate` Rate of completions stopped by the model, i.e. if `finish_reason === 'stop'`
The script will fail if too many completions are truncated; see `llamacpp_completions_truncated_rate`.
k6 metrics can be compared against the [server metrics](../README.md) with:
```shell
curl http://localhost:8080/metrics
```
### Using the CI Python script
The `bench.py` script performs several steps:
- start the server
- define the appropriate variables for k6
- run the k6 script
- extract metrics from Prometheus
It is intended to be used in CI, but you can also run it manually:
```shell
LLAMA_SERVER_BIN_PATH=../../../cmake-build-release/bin/llama-server python bench.py \
--runner-label local \
--name local \
--branch `git rev-parse --abbrev-ref HEAD` \
--commit `git rev-parse HEAD` \
--scenario script.js \
--duration 5m \
--hf-repo ggml-org/models \
--hf-file phi-2/ggml-model-q4_0.gguf \
--model-path-prefix models \
--parallel 4 \
-ngl 33 \
--batch-size 2048 \
--ubatch-size 256 \
--ctx-size 4096 \
--n-prompts 200 \
--max-prompt-tokens 256 \
--max-tokens 256
```

tools/server/bench/bench.py (323 lines, new file)

@@ -0,0 +1,323 @@
from __future__ import annotations
import argparse
import json
import os
import re
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
from contextlib import closing
from datetime import datetime
import matplotlib
import matplotlib.dates
import matplotlib.pyplot as plt
import requests
from statistics import mean
def main(args_in: list[str] | None = None) -> None:
parser = argparse.ArgumentParser(description="Start server benchmark scenario")
parser.add_argument("--name", type=str, help="Bench name", required=True)
parser.add_argument("--runner-label", type=str, help="Runner label", required=True)
parser.add_argument("--branch", type=str, help="Branch name", default="detached")
parser.add_argument("--commit", type=str, help="Commit name", default="dirty")
parser.add_argument("--host", type=str, help="Server listen host", default="0.0.0.0")
parser.add_argument("--port", type=int, help="Server listen host", default="8080")
parser.add_argument("--model-path-prefix", type=str, help="Prefix where to store the model files", default="models")
parser.add_argument("--n-prompts", type=int,
help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", required=True)
parser.add_argument("--max-prompt-tokens", type=int,
help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum prompt tokens to filter out in the dataset",
required=True)
parser.add_argument("--max-tokens", type=int,
help="SERVER_BENCH_MAX_CONTEXT: maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens",
required=True)
parser.add_argument("--hf-repo", type=str, help="Hugging Face model repository", required=True)
parser.add_argument("--hf-file", type=str, help="Hugging Face model file", required=True)
parser.add_argument("-ngl", "--n-gpu-layers", type=int, help="layers to the GPU for computation", required=True)
parser.add_argument("--ctx-size", type=int, help="Set the size of the prompt context", required=True)
parser.add_argument("--parallel", type=int, help="Set the number of slots for process requests", required=True)
parser.add_argument("--batch-size", type=int, help="Set the batch size for prompt processing", required=True)
parser.add_argument("--ubatch-size", type=int, help="physical maximum batch size", required=True)
parser.add_argument("--scenario", type=str, help="Scenario to run", required=True)
parser.add_argument("--duration", type=str, help="Bench scenario", required=True)
args = parser.parse_args(args_in)
start_time = time.time()
# Start the server and performance scenario
try:
server_process = start_server(args)
except Exception:
print("bench: server start error :")
traceback.print_exc(file=sys.stdout)
sys.exit(1)
# start the benchmark
iterations = 0
data = {}
try:
start_benchmark(args)
with open("results.github.env", 'w') as github_env:
# parse output
with open('k6-results.json', 'r') as bench_results:
# Load JSON data from file
data = json.load(bench_results)
for metric_name in data['metrics']:
for metric_metric in data['metrics'][metric_name]:
value = data['metrics'][metric_name][metric_metric]
if isinstance(value, float) or isinstance(value, int):
value = round(value, 2)
data['metrics'][metric_name][metric_metric]=value
github_env.write(
f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_metric)}={value}\n")
iterations = data['root_group']['checks']['success completion']['passes']
except Exception:
print("bench: error :")
traceback.print_exc(file=sys.stdout)
# Stop the server
if server_process:
try:
print(f"bench: shutting down server pid={server_process.pid} ...")
if os.name == 'nt':
interrupt = signal.CTRL_C_EVENT
else:
interrupt = signal.SIGINT
server_process.send_signal(interrupt)
server_process.wait(0.5)
except subprocess.TimeoutExpired:
print(f"server still alive after 500ms, force-killing pid={server_process.pid} ...")
server_process.kill() # SIGKILL
server_process.wait()
while is_server_listening(args.host, args.port):
time.sleep(0.1)
title = (f"llama.cpp {args.name} on {args.runner_label}\n "
f"duration={args.duration} {iterations} iterations")
xlabel = (f"{args.hf_repo}/{args.hf_file}\n"
f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n"
f"branch={args.branch} commit={args.commit}")
# Prometheus
end_time = time.time()
prometheus_metrics = {}
if is_server_listening("0.0.0.0", 9090):
metrics = ['prompt_tokens_seconds', 'predicted_tokens_seconds',
'kv_cache_usage_ratio', 'requests_processing', 'requests_deferred']
for metric in metrics:
resp = requests.get(f"http://localhost:9090/api/v1/query_range",
params={'query': 'llamacpp:' + metric, 'start': start_time, 'end': end_time, 'step': 2})
with open(f"{metric}.json", 'w') as metric_json:
metric_json.write(resp.text)
if resp.status_code != 200:
print(f"bench: unable to extract prometheus metric {metric}: {resp.text}")
else:
metric_data = resp.json()
values = metric_data['data']['result'][0]['values']
timestamps, metric_values = zip(*values)
metric_values = [float(value) for value in metric_values]
prometheus_metrics[metric] = metric_values
timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps]
plt.figure(figsize=(16, 10), dpi=80)
plt.plot(timestamps_dt, metric_values, label=metric)
plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7)
plt.yticks(fontsize=12, alpha=.7)
ylabel = f"llamacpp:{metric}"
plt.title(title,
fontsize=14, wrap=True)
plt.grid(axis='both', alpha=.3)
plt.ylabel(ylabel, fontsize=22)
plt.xlabel(xlabel, fontsize=14, wrap=True)
plt.gca().xaxis.set_major_locator(matplotlib.dates.MinuteLocator())
plt.gca().xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S"))
plt.gcf().autofmt_xdate()
# Remove borders
plt.gca().spines["top"].set_alpha(0.0)
plt.gca().spines["bottom"].set_alpha(0.3)
plt.gca().spines["right"].set_alpha(0.0)
plt.gca().spines["left"].set_alpha(0.3)
# Save the plot as a jpg image
plt.savefig(f'{metric}.jpg', dpi=60)
plt.close()
# Mermaid format in case images upload failed
with open(f"{metric}.mermaid", 'w') as mermaid_f:
mermaid = (
f"""---
config:
xyChart:
titleFontSize: 12
width: 900
height: 600
themeVariables:
xyChart:
titleColor: "#000000"
---
xychart-beta
title "{title}"
y-axis "llamacpp:{metric}"
x-axis "llamacpp:{metric}" {int(min(timestamps))} --> {int(max(timestamps))}
line [{', '.join([str(round(float(value), 2)) for value in metric_values])}]
""")
mermaid_f.write(mermaid)
# 140 chars max for commit status description
bench_results = {
"i": iterations,
"req": {
"p95": round(data['metrics']["http_req_duration"]["p(95)"], 2),
"avg": round(data['metrics']["http_req_duration"]["avg"], 2),
},
"pp": {
"p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
"avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0,
},
"tg": {
"p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
"avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0,
},
}
with open("results.github.env", 'a') as github_env:
github_env.write(f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':') )}\n")
github_env.write(f"BENCH_ITERATIONS={iterations}\n")
title = title.replace('\n', ' ')
xlabel = xlabel.replace('\n', ' ')
github_env.write(f"BENCH_GRAPH_TITLE={title}\n")
github_env.write(f"BENCH_GRAPH_XLABEL={xlabel}\n")
def start_benchmark(args):
k6_path = './k6'
if 'BENCH_K6_BIN_PATH' in os.environ:
k6_path = os.environ['BENCH_K6_BIN_PATH']
k6_args = [
'run', args.scenario,
'--no-color',
'--no-connection-reuse',
'--no-vu-connection-reuse',
]
k6_args.extend(['--duration', args.duration])
k6_args.extend(['--iterations', args.n_prompts])
k6_args.extend(['--vus', args.parallel])
k6_args.extend(['--summary-export', 'k6-results.json'])
k6_args.extend(['--out', 'csv=k6-results.csv'])
args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
print(f"bench: starting k6 with: {args}")
k6_completed = subprocess.run(args, shell=True, stdout=sys.stdout, stderr=sys.stderr)
if k6_completed.returncode != 0:
raise Exception("bench: unable to run k6")
def start_server(args):
server_process = start_server_background(args)
attempts = 0
max_attempts = 600
if 'GITHUB_ACTIONS' in os.environ:
max_attempts *= 2
while not is_server_listening(args.host, args.port):
attempts += 1
if attempts > max_attempts:
assert False, "server not started"
print(f"bench: waiting for server to start ...")
time.sleep(0.5)
attempts = 0
while not is_server_ready(args.host, args.port):
attempts += 1
if attempts > max_attempts:
assert False, "server not ready"
print(f"bench: waiting for server to be ready ...")
time.sleep(0.5)
print("bench: server started and ready.")
return server_process
def start_server_background(args):
# Start the server
server_path = '../../../build/bin/llama-server'
if 'LLAMA_SERVER_BIN_PATH' in os.environ:
server_path = os.environ['LLAMA_SERVER_BIN_PATH']
server_args = [
'--host', args.host,
'--port', args.port,
]
server_args.extend(['--hf-repo', args.hf_repo])
server_args.extend(['--hf-file', args.hf_file])
server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
server_args.extend(['--ctx-size', args.ctx_size])
server_args.extend(['--parallel', args.parallel])
server_args.extend(['--batch-size', args.batch_size])
server_args.extend(['--ubatch-size', args.ubatch_size])
server_args.extend(['--n-predict', args.max_tokens * 2])
server_args.extend(['--defrag-thold', "0.1"])
server_args.append('--cont-batching')
server_args.append('--metrics')
server_args.append('--flash-attn')
args = [str(arg) for arg in [server_path, *server_args]]
print(f"bench: starting server with: {' '.join(args)}")
pkwargs = {
'stdout': subprocess.PIPE,
'stderr': subprocess.PIPE
}
server_process = subprocess.Popen(
args,
**pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue]
def server_log(in_stream, out_stream):
for line in iter(in_stream.readline, b''):
print(line.decode('utf-8'), end='', file=out_stream)
thread_stdout = threading.Thread(target=server_log, args=(server_process.stdout, sys.stdout))
thread_stdout.start()
thread_stderr = threading.Thread(target=server_log, args=(server_process.stderr, sys.stderr))
thread_stderr.start()
return server_process
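# Returns True when a TCP connection to server_fqdn:server_port succeeds, i.e. the port is open.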
def is_server_listening(server_fqdn, server_port):
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
result = sock.connect_ex((server_fqdn, server_port))
_is_server_listening = result == 0
if _is_server_listening:
print(f"server is listening on {server_fqdn}:{server_port}...")
return _is_server_listening
def is_server_ready(server_fqdn, server_port):
url = f"http://{server_fqdn}:{server_port}/health"
response = requests.get(url)
return response.status_code == 200
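# Uppercase the metric name and replace any non-alphanumeric character with '_' so it can be used as a key in the GitHub env file.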
def escape_metric_name(metric_name):
return re.sub('[^A-Z0-9]', '_', metric_name.upper())
if __name__ == '__main__':
main()

@@ -0,0 +1,9 @@
global:
scrape_interval: 10s
external_labels:
llamacpp: 'server'
scrape_configs:
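# Scrape the llama-server /metrics endpoint (the server must be started with --metrics).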
- job_name: 'llama.cpp server'
static_configs:
- targets: ['localhost:8080']

@@ -0,0 +1,2 @@
matplotlib
requests

@@ -0,0 +1,162 @@
import sse from 'k6/x/sse'
import {check, sleep} from 'k6'
import {SharedArray} from 'k6/data'
import {Counter, Rate, Trend} from 'k6/metrics'
import exec from 'k6/execution';
// Server chat completions prefix
const server_url = __ENV.SERVER_BENCH_URL ? __ENV.SERVER_BENCH_URL : 'http://localhost:8080/v1'
// Number of total prompts in the dataset - default 10m / 10 seconds/request * number of users
const n_prompt = __ENV.SERVER_BENCH_N_PROMPTS ? parseInt(__ENV.SERVER_BENCH_N_PROMPTS) : 600 / 10 * 8
// Model name to request
const model = __ENV.SERVER_BENCH_MODEL_ALIAS ? __ENV.SERVER_BENCH_MODEL_ALIAS : 'my-model'
// Dataset path
const dataset_path = __ENV.SERVER_BENCH_DATASET ? __ENV.SERVER_BENCH_DATASET : './ShareGPT_V3_unfiltered_cleaned_split.json'
// Max tokens to predict
const max_tokens = __ENV.SERVER_BENCH_MAX_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_TOKENS) : 512
// Max prompt tokens
const n_prompt_tokens = __ENV.SERVER_BENCH_MAX_PROMPT_TOKENS ? parseInt(__ENV.SERVER_BENCH_MAX_PROMPT_TOKENS) : 1024
// Max slot context
const n_ctx_slot = __ENV.SERVER_BENCH_MAX_CONTEXT ? parseInt(__ENV.SERVER_BENCH_MAX_CONTEXT) : 2048
export function setup() {
console.info(`Benchmark config: server_url=${server_url} n_prompt=${n_prompt} model=${model} dataset_path=${dataset_path} max_tokens=${max_tokens}`)
}
const data = new SharedArray('conversations', function () {
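// crude token count: split on whitespace and basic punctuation (see the README note about the local tokenizer)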
const tokenizer = (message) => message.split(/[\s,'".?]/)
return JSON.parse(open(dataset_path))
// Filter out the conversations with less than 2 turns.
.filter(data => data["conversations"].length >= 2)
.filter(data => data["conversations"][0]["from"] === "human")
.map(data => {
return {
prompt: data["conversations"][0]["value"],
n_prompt_tokens: tokenizer(data["conversations"][0]["value"]).length,
n_completion_tokens: tokenizer(data["conversations"][1]["value"]).length,
}
})
// Filter out too short sequences
.filter(conv => conv.n_prompt_tokens >= 4 && conv.n_completion_tokens >= 4)
// Filter out too long sequences.
.filter(conv => conv.n_prompt_tokens <= n_prompt_tokens && conv.n_prompt_tokens + conv.n_completion_tokens <= n_ctx_slot)
// Keep only first n prompts
.slice(0, n_prompt)
})
const llamacpp_prompt_tokens = new Trend('llamacpp_prompt_tokens')
const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens')
const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second')
const llamacpp_emit_first_token_second = new Trend('llamacpp_emit_first_token_second')
const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
const llamacpp_completions_truncated_rate = new Rate('llamacpp_completions_truncated_rate')
const llamacpp_completions_stop_rate = new Rate('llamacpp_completions_stop_rate')
export const options = {
thresholds: {
llamacpp_completions_truncated_rate: [
// more than 80% of truncated input will abort the test
{threshold: 'rate < 0.8', abortOnFail: true, delayAbortEval: '1m'},
],
},
duration: '10m',
vus: 8,
}
export default function () {
const conversation = data[exec.scenario.iterationInInstance % data.length]
const payload = {
"messages": [
{
"role": "system",
"content": "You are ChatGPT, an AI assistant.",
},
{
"role": "user",
"content": conversation.prompt,
}
],
"model": model,
"stream": true,
"stream_options": {
"include_usage": true, // False to be supported in llama.cpp server
},
"seed": 42,
"max_tokens": max_tokens,
"stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS
}
const params = {method: 'POST', body: JSON.stringify(payload)};
const startTime = new Date()
let promptEvalEndTime = null
let prompt_tokens = 0
let completions_tokens = 0
let finish_reason = null
const res = sse.open(`${server_url}/chat/completions`, params, function (client) {
client.on('event', function (event) {
if (promptEvalEndTime == null) {
promptEvalEndTime = new Date()
llamacpp_emit_first_token_second.add((promptEvalEndTime - startTime) / 1.e3)
}
if (event.data === '[DONE]' || event.data === '') {
return
}
let chunk = JSON.parse(event.data)
if (chunk.choices && chunk.choices.length > 0) {
let choice = chunk.choices[0]
if (choice.finish_reason) {
finish_reason = choice.finish_reason
}
}
if (chunk.usage) {
prompt_tokens = chunk.usage.prompt_tokens
llamacpp_prompt_tokens.add(prompt_tokens)
llamacpp_prompt_tokens_total_counter.add(prompt_tokens)
completions_tokens = chunk.usage.completion_tokens
llamacpp_completion_tokens.add(completions_tokens)
llamacpp_completion_tokens_total_counter.add(completions_tokens)
}
})
client.on('error', function (e) {
console.log('An unexpected error occurred: ', e.error());
throw e;
})
})
check(res, {'success completion': (r) => r.status === 200})
const endTime = new Date()
const promptEvalTime = promptEvalEndTime - startTime
if (promptEvalTime > 0) {
llamacpp_prompt_processing_second.add(prompt_tokens / (promptEvalEndTime - startTime) * 1.e3)
}
const completion_time = endTime - promptEvalEndTime
if (completions_tokens > 0 && completion_time > 0) {
llamacpp_tokens_second.add(completions_tokens / completion_time * 1.e3)
}
llamacpp_completions_truncated_rate.add(finish_reason === 'length')
llamacpp_completions_stop_rate.add(finish_reason === 'stop')
sleep(0.3)
}

tools/server/chat-llama2.sh (109 lines, new executable file)

@@ -0,0 +1,109 @@
#!/bin/bash
API_URL="${API_URL:-http://127.0.0.1:8080}"
CHAT=(
"Hello, Assistant."
"Hello. How may I help you today?"
)
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
trim() {
shopt -s extglob
set -- "${1##+([[:space:]])}"
printf "%s" "${1%%+([[:space:]])}"
}
trim_trailing() {
shopt -s extglob
printf "%s" "${1%%+([[:space:]])}"
}
format_prompt() {
if [[ "${#CHAT[@]}" -eq 0 ]]; then
echo -n "[INST] <<SYS>>\n${INSTRUCTION}\n<</SYS>>"
else
LAST_INDEX=$(( ${#CHAT[@]} - 1 ))
echo -n "${CHAT[$LAST_INDEX]}\n[INST] $1 [/INST]"
fi
}
tokenize() {
curl \
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--header "Content-Type: application/json" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
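# Token count of the system prompt; passed to the server as n_keep so these tokens are preserved when the context overflows.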
N_KEEP=$(tokenize "[INST] <<SYS>>\n${INSTRUCTION}\n<</SYS>>" | wc -l)
chat_completion() {
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
prompt: .,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: $n_keep,
n_predict: 1024,
stop: ["[INST]"],
stream: true
}')"
# Create a temporary file to hold the Python output
TEMPFILE=$(mktemp)
exec 3< <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--header "Content-Type: application/json" \
--data-raw "${DATA}")
python -c "
import json
import sys
answer = ''
while True:
line = sys.stdin.readline()
if not line:
break
if line.startswith('data: '):
json_content = line[6:].strip()
content = json.loads(json_content)['content']
sys.stdout.write(content)
sys.stdout.flush()
answer += content
answer = answer.rstrip('\n')
# Write the answer to the temporary file
with open('$TEMPFILE', 'w') as f:
f.write(answer)
" <&3
exec 3<&-
# Read the answer from the temporary file
ANSWER=$(cat $TEMPFILE)
# Clean up the temporary file
rm $TEMPFILE
printf "\n"
CHAT+=("$1" "$(trim "$ANSWER")")
}
while true; do
echo -en "\033[0;32m" # Green color
read -r -e -p "> " QUESTION
echo -en "\033[0m" # Reset color
chat_completion "${QUESTION}"
done

tools/server/chat.mjs (131 lines, new file)

@@ -0,0 +1,131 @@
import * as readline from 'node:readline'
import { stdin, stdout } from 'node:process'
import { readFileSync } from 'node:fs'
import { SchemaConverter } from './public_legacy/json-schema-to-grammar.mjs'
const args = process.argv.slice(2);
const grammarJsonSchemaFile = args.find(
(_, index) => args[index - 1] === "--grammar-json-schema"
);
const no_cached_prompt = args.find(
(_, index) => args[index - 1] === "--no-cache-prompt"
) ?? "false";
const grammarFile = args.find((_, index) => args[index - 1] === "--grammar");
// Example usage: function,arguments
const grammarJsonSchemaPropOrder = args.find(
(_, index) => args[index - 1] === "--grammar-json-schema-prop-order"
);
const propOrder = grammarJsonSchemaPropOrder
? grammarJsonSchemaPropOrder
.split(",")
.reduce((acc, cur, index) => ({ ...acc, [cur]: index }), {})
: {};
let grammar = null
if (grammarJsonSchemaFile) {
let schema = JSON.parse(readFileSync(grammarJsonSchemaFile, 'utf-8'))
const converter = new SchemaConverter({prop_order: propOrder, allow_fetch: true})
schema = await converter.resolveRefs(schema, grammarJsonSchemaFile)
converter.visit(schema, '')
grammar = converter.formatGrammar()
}
if (grammarFile) {
grammar = readFileSync(grammarFile, 'utf-8')
}
// for cached prompt
let slot_id = -1;
const API_URL = 'http://127.0.0.1:8080'
const chat = [
{
human: "Hello, Assistant.",
assistant: "Hello. How may I help you today?"
},
{
human: "Please tell me the largest city in Europe.",
assistant: "Sure. The largest city in Europe is Moscow, the capital of Russia."
},
]
const instruction = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.`
function format_prompt(question) {
return `${instruction}\n${
chat.map(m =>`### Human: ${m.human}\n### Assistant: ${m.assistant}`).join("\n")
}\n### Human: ${question}\n### Assistant:`
}
async function tokenize(content) {
const result = await fetch(`${API_URL}/tokenize`, {
method: 'POST',
body: JSON.stringify({ content })
})
if (!result.ok) {
return []
}
return (await result.json()).tokens
}
const n_keep = (await tokenize(instruction)).length
async function chat_completion(question) {
const result = await fetch(`${API_URL}/completion`, {
method: 'POST',
body: JSON.stringify({
prompt: format_prompt(question),
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: n_keep,
n_predict: 256,
cache_prompt: no_cached_prompt === "false",
slot_id: slot_id,
stop: ["\n### Human:"], // stop completion after generating this
grammar,
stream: true,
})
})
if (!result.ok) {
return
}
let answer = ''
for await (var chunk of result.body) {
const t = Buffer.from(chunk).toString('utf8')
if (t.startsWith('data: ')) {
const message = JSON.parse(t.substring(6))
slot_id = message.slot_id
answer += message.content
process.stdout.write(message.content)
if (message.stop) {
if (message.truncated) {
chat.shift()
}
break
}
}
}
process.stdout.write('\n')
chat.push({ human: question, assistant: answer.trimStart() })
}
const rl = readline.createInterface({ input: stdin, output: stdout });
const readlineQuestion = (rl, query, options) => new Promise((resolve, reject) => {
rl.question(query, options, resolve)
});
while(true) {
const question = await readlineQuestion(rl, '> ')
await chat_completion(question)
}

tools/server/chat.sh (80 lines, new executable file)

@@ -0,0 +1,80 @@
#!/bin/bash
API_URL="${API_URL:-http://127.0.0.1:8080}"
CHAT=(
"Hello, Assistant."
"Hello. How may I help you today?"
"Please tell me the largest city in Europe."
"Sure. The largest city in Europe is Moscow, the capital of Russia."
)
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
trim() {
shopt -s extglob
set -- "${1##+([[:space:]])}"
printf "%s" "${1%%+([[:space:]])}"
}
trim_trailing() {
shopt -s extglob
printf "%s" "${1%%+([[:space:]])}"
}
format_prompt() {
echo -n "${INSTRUCTION}"
printf "\n### Human: %s\n### Assistant: %s" "${CHAT[@]}" "$1"
}
tokenize() {
curl \
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--header "Content-Type: application/json" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
N_KEEP=$(tokenize "${INSTRUCTION}" | wc -l)
chat_completion() {
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
prompt: .,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: $n_keep,
n_predict: 256,
cache_prompt: true,
stop: ["\n### Human:"],
stream: true
}')"
ANSWER=''
while IFS= read -r LINE; do
if [[ $LINE = data:* ]]; then
CONTENT="$(echo "${LINE:5}" | jq -r '.content')"
printf "%s" "${CONTENT}"
ANSWER+="${CONTENT}"
fi
done < <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--header "Content-Type: application/json" \
--data-raw "${DATA}")
printf "\n"
CHAT+=("$1" "$(trim "$ANSWER")")
}
while true; do
read -r -e -p "> " QUESTION
chat_completion "${QUESTION}"
done

tools/server/httplib.h (10506 lines, new file)
File diff suppressed because it is too large.

Binary file not shown.

@@ -0,0 +1,12 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="refresh" content="5">
</head>
<body>
<div id="loading">
The model is loading. Please wait.<br/>
The user interface will appear soon.
</div>
</body>
</html>

@@ -0,0 +1,402 @@
@import url("theme-snowstorm.css");
@import url("theme-polarnight.css");
@import url("theme-ketivah.css");
@import url("theme-mangotango.css");
@import url("theme-playground.css");
@import url("theme-beeninorder.css");
:root {
/* ---------- PRIMARY COLORS ----------------- */
--primary-color-1: hsl(217.5, 26.7%, 94.1%);
--primary-color-1-hue: 217.5;
--primary-color-1-saturation: 26.7%;
--primary-color-1-lightness: 94.1%;
--primary-color-2: hsl(218.2, 26.8%, 92.0%);
--primary-color-2-hue: 218.2;
--primary-color-2-saturation: 26.8%;
--primary-color-2-lightness: 92.0%;
--primary-color-3: hsl(218.8, 27.9%, 88.0%);
--primary-color-3-hue: 218.8;
--primary-color-3-saturation: 27.9%;
--primary-color-3-lightness: 88.0%;
--primary-color-4: hsl(218.8, 18.3%, 81.8%);
--primary-color-4-hue: 218.8;
--primary-color-4-saturation: 18.3%;
--primary-color-4-lightness: 81.8%;
/* ---------- SECONDARY COLORS --------------- */
--secondary-color-1: hsl(220.0, 16.4%, 21.6%);
--secondary-color-1-hue: 220.0;
--secondary-color-1-saturation: 16.4%;
--secondary-color-1-lightness: 21.6%;
--secondary-color-2: hsl(221.7, 16.3%, 27.6%);
--secondary-color-2-hue: 221.7;
--secondary-color-2-saturation: 16.3%;
--secondary-color-2-lightness: 27.6%;
--secondary-color-3: hsl(220.0, 16.8%, 31.6%);
--secondary-color-3-hue: 220.0;
--secondary-color-3-saturation: 16.8%;
--secondary-color-3-lightness: 31.6%;
--secondary-color-4: hsl(220.0, 16.5%, 35.7%);
--secondary-color-4-hue: 220.0;
--secondary-color-4-saturation: 16.5%;
--secondary-color-4-lightness: 35.7%;
/* ----------- NUANCES COLORS ---------------- */
--theme-nuance-color-1: hsl(178.7, 25.1%, 64.9%);
--theme-nuance-color-1-hue: 178.7;
--theme-nuance-color-1-saturation: 25.1%;
--theme-nuance-color-1-lightness: 64.9%;
--theme-nuance-color-2: hsl(193.3, 43.4%, 67.5%);
--theme-nuance-color-2-hue: 193.3;
--theme-nuance-color-2-saturation: 43.4%;
--theme-nuance-color-2-lightness: 67.5%;
--theme-nuance-color-3: hsl(210.0, 34.0%, 63.1%);
--theme-nuance-color-3-hue: 210.0;
--theme-nuance-color-3-saturation: 34.0%;
--theme-nuance-color-3-lightness: 63.1%;
--theme-nuance-color-4: hsl(213.1, 32.0%, 52.2%);
--theme-nuance-color-4-hue: 213.1;
--theme-nuance-color-4-saturation: 32.0%;
--theme-nuance-color-4-lightness: 52.2%;
/* ----------- ROYGP COLORS ------------------ */
--theme-red-color: hsl(32.5, 80%, 50%);
--theme-orange-color: hsl(32.5, 70%, 45%);
--theme-yellow-color: hsl(40.0, 0.6%, 73.3%);
--theme-green-color: hsl(92.4, 27.8%, 64.7%);
--theme-purple-color: hsl(311.1, 20.2%, 63.1%);
/* ------------------------------------------- */
--background-color-1: var(--primary-color-1);
--background-color-2: var(--primary-color-2);
--background-color-3: var(--primary-color-3);
--background-color-4: var(--primary-color-4);
--border-color-1: var(--primary-color-2);
--border-color-2: var(--primary-color-3);
--border-color-3: var(--primary-color-4);
--border-focus-color: var(--theme-nuance-color-2);
--border-focus-shadow: var(--theme-nuance-color-1);
--text-color-plain: var(--secondary-color-1);
--text-color-subtile-1: var(--secondary-color-2);
--text-color-subtile-2: var(--secondary-color-3);
--code-background-color: var(--secondary-color-2);
--code-text-color: var(--primary-color-2);
--ui-range-thumb-color: var(--theme-nuance-color-3);
--ui-range-thumb-border: var(--ui-ranger-thumb-color);
--textarea-border-color: var(--secondary-color-4);
--chat-id-color: var(--theme-nuance-color-4);
/* ------------------------------------------- */
--button-alert-text-hover: var(--primary-color-1);
--button-alert-color-hover: var(--theme-orange-color);
--button-alert-border-hover: var(--theme-orange-color);
--button-alert-text-active: var(--primary-color-1);
--button-alert-color-active: var(--theme-red-color);
--button-alert-border-active: var(--theme-red-color);
/* ----------- PRIMARY BUTTONS --------------- */
/* - button should immediately catch the eye - */
--button-primary-text: var(--secondary-color-1);
--button-primary-color: var(--theme-nuance-color-3);
--button-primary-border: var(--theme-nuance-color-3);
/* ---------hover---------- */
--button-primary-text-hover:
hsl(217.5,
calc(var(--secondary-color-1-saturation) + 35%),
calc(var(--secondary-color-1-lightness) - 30%));
--button-primary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
--button-primary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
/* ---------active--------- */
--button-primary-text-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 35%));
--button-primary-color-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 25%));
--button-primary-border-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 25%));
/* ---------- SECONDARY BUTTONS -------------- */
/* these should NOT immediately catch the eye */
--button-secondary-text:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 50%));
--button-secondary-color:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
--button-secondary-border:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
/* ---------hover---------- */
--button-secondary-text-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 22%),
calc(var(--theme-nuance-color-3-lightness) + 1%));
--button-secondary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 22%),
calc(var(--theme-nuance-color-3-lightness) + 1%));
/* ---------active--------- */
--button-secondary-text-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) + 40%),
calc(var(--theme-nuance-color-3-lightness) - 55%));
--button-secondary-color-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 30%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
--button-secondary-border-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 30%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
/* ---------- TERTIARY BUTTONS --------------- */
/* ---------- disabled buttons --------------- */
--button-tertiary-text:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
--button-tertiary-color:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
--button-tertiary-border:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
/* ---------hover---------- */
--button-tertiary-text-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
--button-tertiary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
--button-tertiary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
}
/*
.theme-template {
If light theme: should go from bright to darker
If dark theme: should go from dark to brighter
ideally this should not be anything but steps of
gray or slightly variants from it
--primary-color-1: #2E3440;
--primary-color-2: #3B4252;
--primary-color-3: #434C5E;
--primary-color-4: #4C566A;
If light theme: should go from dark to brighter
If dark theme: should go from bright to darker
ideally this should not be anything but steps of
gray or slightly variants from it
--secondary-color-1: #ECEFF4;
--secondary-color-2: #E5E9F0;
--secondary-color-3: #D8DEE9;
--secondary-color-4: #C8CED9;
Choose wisely nuance colors. It is not easy to find
4 harmonizing nuance colors. But keep in mind, that
only one accent color could work too.
--theme-nuance-color-1: #8FBCBB;
--theme-nuance-color-2: #88C0D0;
--theme-nuance-color-3: #81A1C1;
--theme-nuance-color-4: #5E81AC;
adapt the color red, orange, yellow, green,
purple to the 'mood' of your overall design
e.g is it low-contrast? vibrant? dynamic? etc
--theme-red-color: #BF616A;
--theme-orange-color: #D08770;
--theme-yellow-color: #EBCB8B;
--theme-green-color: #A3BE8C;
--theme-purple-color: #B48EAD;
NOTE: comment all those lines `--- ...` out
------------------------------------------------
--background-color-1:
--background-color-2:
--background-color-3:
--background-color-4:
--border-color-1:
--border-color-2:
--border-color-3:
--border-focus-color:
--border-focus-shadow:
--text-color-plain:
--text-color-subtile-1:
--text-color-subtile-2:
--code-background-color:
--code-text-color:
--ui-range-thumb-color:
--ui-range-thumb-border:
--textarea-border-color:
-------------------------------------------
--button-alert-text-hover:
--button-alert-color-hover:
--button-alert-border-hover:
--button-alert-text-active:
--button-alert-color-active:
--button-alert-border-active:
----------- PRIMARY -----------------------
--button should immediately catch the eye--
--button-primary-text:
--button-primary-color:
--button-primary-border:
---------hover----------
--button-primary-text-hover:
--button-primary-color-hover:
--button-primary-border-hover:
---------active---------
--button-primary-text-active:
--button-primary-color-active:
--button-primary-border-active:
------------ SECONDARY ------------------------
--button should NOT immediately catch the eye--
--button-secondary-text:
--button-secondary-color:
--button-secondary-border:
---------hover----------
--button-secondary-text-hover:
--button-secondary-color-hover:
--button-secondary-border-hover:
---------active---------
--button-secondary-text-active:
--button-secondary-color-active:
--button-secondary-border-active:
---------- TERTIARY -----------------------
---------- disabled buttons ---------------
--button-tertiary-text:
--button-tertiary-color:
--button-tertiary-border:
---------hover----------
--button-tertiary-text:
--button-tertiary-color:
--button-tertiary-border:
}
*/

@@ -0,0 +1,209 @@
const paramDefaults = {
stream: true,
n_predict: 500,
temperature: 0.2,
stop: ["</s>"]
};
let generation_settings = null;
// Completes the prompt as a generator. Recommended for most use cases.
//
// Example:
//
// import { llama } from '/completion.js'
//
// const request = llama("Tell me a joke", {n_predict: 800})
// for await (const chunk of request) {
// document.write(chunk.data.content)
// }
//
export async function* llama(prompt, params = {}, config = {}) {
let controller = config.controller;
const api_url = config.api_url?.replace(/\/+$/, '') || "";
if (!controller) {
controller = new AbortController();
}
const completionParams = { ...paramDefaults, ...params, prompt };
const response = await fetch(`${api_url}${config.endpoint || '/completion'}`, {
method: 'POST',
body: JSON.stringify(completionParams),
headers: {
'Connection': 'keep-alive',
'Content-Type': 'application/json',
'Accept': 'text/event-stream',
...(params.api_key ? {'Authorization': `Bearer ${params.api_key}`} : {})
},
signal: controller.signal,
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
let content = "";
let leftover = ""; // Buffer for partially read lines
try {
let cont = true;
while (cont) {
const result = await reader.read();
if (result.done) {
break;
}
// Add any leftover data to the current chunk of data
const text = leftover + decoder.decode(result.value);
// Check if the last character is a line break
const endsWithLineBreak = text.endsWith('\n');
// Split the text into lines
let lines = text.split('\n');
// If the text doesn't end with a line break, then the last line is incomplete
// Store it in leftover to be added to the next chunk of data
if (!endsWithLineBreak) {
leftover = lines.pop();
} else {
leftover = ""; // Reset leftover if we have a line break at the end
}
// Parse all sse events and add them to result
const regex = /^(\S+):\s(.*)$/gm;
for (const line of lines) {
const match = regex.exec(line);
if (match) {
result[match[1]] = match[2];
if (result.data === '[DONE]') {
cont = false;
break;
}
// since we know this is llama.cpp, let's just decode the json in data
if (result.data) {
result.data = JSON.parse(result.data);
content += result.data.content;
// yield
yield result;
// if we got a stop token from server, we will break here
if (result.data.stop) {
if (result.data.generation_settings) {
generation_settings = result.data.generation_settings;
}
cont = false;
break;
}
}
if (result.error) {
try {
result.error = JSON.parse(result.error);
if (result.error.message.includes('slot unavailable')) {
// Throw an error to be caught by upstream callers
throw new Error('slot unavailable');
} else {
console.error(`llama.cpp error [${result.error.code} - ${result.error.type}]: ${result.error.message}`);
}
} catch(e) {
console.error(`llama.cpp error ${result.error}`)
}
}
}
}
}
} catch (e) {
if (e.name !== 'AbortError') {
console.error("llama error: ", e);
}
throw e;
}
finally {
controller.abort();
}
return content;
}
// Call llama, return an event target that you can subscribe to
//
// Example:
//
// import { llamaEventTarget } from '/completion.js'
//
// const conn = llamaEventTarget(prompt)
// conn.addEventListener("message", (chunk) => {
// document.write(chunk.detail.content)
// })
//
export const llamaEventTarget = (prompt, params = {}, config = {}) => {
const eventTarget = new EventTarget();
(async () => {
let content = "";
for await (const chunk of llama(prompt, params, config)) {
if (chunk.data) {
content += chunk.data.content;
eventTarget.dispatchEvent(new CustomEvent("message", { detail: chunk.data }));
}
if (chunk.data.generation_settings) {
eventTarget.dispatchEvent(new CustomEvent("generation_settings", { detail: chunk.data.generation_settings }));
}
if (chunk.data.timings) {
eventTarget.dispatchEvent(new CustomEvent("timings", { detail: chunk.data.timings }));
}
}
eventTarget.dispatchEvent(new CustomEvent("done", { detail: { content } }));
})();
return eventTarget;
}
// Call llama, return a promise that resolves to the completed text. This does not support streaming
//
// Example:
//
// llamaPromise(prompt).then((content) => {
// document.write(content)
// })
//
// or
//
// const content = await llamaPromise(prompt)
// document.write(content)
//
export const llamaPromise = (prompt, params = {}, config = {}) => {
return new Promise(async (resolve, reject) => {
let content = "";
try {
for await (const chunk of llama(prompt, params, config)) {
content += chunk.data.content;
}
resolve(content);
} catch (error) {
reject(error);
}
});
};
/**
* (deprecated)
*/
export const llamaComplete = async (params, controller, callback) => {
for await (const chunk of llama(params.prompt, params, { controller })) {
callback(chunk);
}
}
// Get the model info from the server. This is useful for getting the context window and so on.
export const llamaModelInfo = async (config = {}) => {
if (!generation_settings) {
const api_url = config.api_url?.replace(/\/+$/, '') || "";
const props = await fetch(`${api_url}/props`).then(r => r.json());
generation_settings = props.default_generation_settings;
}
return generation_settings;
}

Binary file not shown.

Binary image added (4 KiB).

File diff suppressed because it is too large.

File diff suppressed because it is too large.

File diff suppressed because one or more lines are too long.

@@ -0,0 +1,838 @@
// WARNING: This file was ported from json_schema_to_grammar.py, please fix bugs / add features there first.
const SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}';
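// Build a GBNF repetition of itemRule between minItems and maxItems times (maxItems undefined = unbounded),
// optionally separated by opts.separatorRule.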
function _buildRepetition(itemRule, minItems, maxItems, opts={}) {
if (maxItems == 0) {
return '';
}
if (minItems === 0 && maxItems === 1) {
return `${itemRule}?`;
}
const separatorRule = opts.separatorRule ?? '';
const itemRuleIsLiteral = opts.itemRuleIsLiteral ?? false
if (separatorRule === '') {
if (minItems === 1 && maxItems === undefined) {
return `${itemRule}+`;
} else if (minItems === 0 && maxItems === undefined) {
return `${itemRule}*`;
} else {
return `${itemRule}{${minItems},${maxItems !== undefined ? maxItems : ''}}`;
}
}
const result = itemRule + ' ' + _buildRepetition(`(${separatorRule} ${itemRule})`, minItems > 0 ? minItems - 1 : 0, maxItems !== undefined ? maxItems - 1 : undefined);
return minItems === 0 ? `(${result})?` : result;
}
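// Emit into `out` a GBNF fragment matching decimal integers between minValue and maxValue
// (either bound may be null for an open range), working digit by digit on the string forms.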
function _generateMinMaxInt(minValue, maxValue, out, decimalsLeft = 16, topLevel = true) {
const hasMin = minValue !== null;
const hasMax = maxValue !== null;
function digitRange(fromChar, toChar) {
out.push("[");
if (fromChar === toChar) {
out.push(fromChar);
} else {
out.push(fromChar);
out.push("-");
out.push(toChar);
}
out.push("]");
}
function moreDigits(minDigits, maxDigits) {
out.push("[0-9]");
if (minDigits === maxDigits && minDigits === 1) {
return;
}
out.push("{");
out.push(minDigits.toString());
if (maxDigits !== minDigits) {
out.push(",");
if (maxDigits !== Number.MAX_SAFE_INTEGER) {
out.push(maxDigits.toString());
}
}
out.push("}");
}
function uniformRange(fromStr, toStr) {
let i = 0;
while (i < fromStr.length && fromStr[i] === toStr[i]) {
i++;
}
if (i > 0) {
out.push("\"");
out.push(fromStr.slice(0, i));
out.push("\"");
}
if (i < fromStr.length) {
if (i > 0) {
out.push(" ");
}
const subLen = fromStr.length - i - 1;
if (subLen > 0) {
const fromSub = fromStr.slice(i + 1);
const toSub = toStr.slice(i + 1);
const subZeros = "0".repeat(subLen);
const subNines = "9".repeat(subLen);
let toReached = false;
out.push("(");
if (fromSub === subZeros) {
digitRange(fromStr[i], String.fromCharCode(toStr.charCodeAt(i) - 1));
out.push(" ");
moreDigits(subLen, subLen);
} else {
out.push("[");
out.push(fromStr[i]);
out.push("] ");
out.push("(");
uniformRange(fromSub, subNines);
out.push(")");
if (fromStr.charCodeAt(i) < toStr.charCodeAt(i) - 1) {
out.push(" | ");
if (toSub === subNines) {
digitRange(String.fromCharCode(fromStr.charCodeAt(i) + 1), toStr[i]);
toReached = true;
} else {
digitRange(String.fromCharCode(fromStr.charCodeAt(i) + 1), String.fromCharCode(toStr.charCodeAt(i) - 1));
}
out.push(" ");
moreDigits(subLen, subLen);
}
}
if (!toReached) {
out.push(" | ");
digitRange(toStr[i], toStr[i]);
out.push(" ");
uniformRange(subZeros, toSub);
}
out.push(")");
} else {
out.push("[");
out.push(fromStr[i]);
out.push("-");
out.push(toStr[i]);
out.push("]");
}
}
}
if (hasMin && hasMax) {
if (minValue < 0 && maxValue < 0) {
out.push("\"-\" (");
_generateMinMaxInt(-maxValue, -minValue, out, decimalsLeft, true);
out.push(")");
return;
}
if (minValue < 0) {
out.push("\"-\" (");
_generateMinMaxInt(0, -minValue, out, decimalsLeft, true);
out.push(") | ");
minValue = 0;
}
let minS = minValue.toString();
const maxS = maxValue.toString();
const minDigits = minS.length;
const maxDigits = maxS.length;
for (let digits = minDigits; digits < maxDigits; digits++) {
uniformRange(minS, "9".repeat(digits));
minS = "1" + "0".repeat(digits);
out.push(" | ");
}
uniformRange(minS, maxS);
return;
}
const lessDecimals = Math.max(decimalsLeft - 1, 1);
if (hasMin) {
if (minValue < 0) {
out.push("\"-\" (");
_generateMinMaxInt(null, -minValue, out, decimalsLeft, false);
out.push(") | [0] | [1-9] ");
moreDigits(0, decimalsLeft - 1);
} else if (minValue === 0) {
if (topLevel) {
out.push("[0] | [1-9] ");
moreDigits(0, lessDecimals);
} else {
moreDigits(1, decimalsLeft);
}
} else if (minValue <= 9) {
const c = minValue.toString();
const range_start = topLevel ? '1' : '0';
if (c > range_start) {
digitRange(range_start, String.fromCharCode(c.charCodeAt(0) - 1));
out.push(" ");
moreDigits(1, lessDecimals);
out.push(" | ");
}
digitRange(c, "9");
out.push(" ");
moreDigits(0, lessDecimals);
} else {
const minS = minValue.toString();
const length = minS.length;
const c = minS[0];
if (c > "1") {
digitRange(topLevel ? "1" : "0", String.fromCharCode(c.charCodeAt(0) - 1));
out.push(" ");
moreDigits(length, lessDecimals);
out.push(" | ");
}
digitRange(c, c);
out.push(" (");
_generateMinMaxInt(parseInt(minS.slice(1)), null, out, lessDecimals, false);
out.push(")");
if (c < "9") {
out.push(" | ");
digitRange(String.fromCharCode(c.charCodeAt(0) + 1), "9");
out.push(" ");
moreDigits(length - 1, lessDecimals);
}
}
return;
}
if (hasMax) {
if (maxValue >= 0) {
if (topLevel) {
out.push("\"-\" [1-9] ");
moreDigits(0, lessDecimals);
out.push(" | ");
}
_generateMinMaxInt(0, maxValue, out, decimalsLeft, true);
} else {
out.push("\"-\" (");
_generateMinMaxInt(-maxValue, null, out, decimalsLeft, false);
out.push(")");
}
return;
}
throw new Error("At least one of minValue or maxValue must be set");
}
class BuiltinRule {
constructor(content, deps) {
this.content = content;
this.deps = deps || [];
}
}
const PRIMITIVE_RULES = {
boolean : new BuiltinRule('("true" | "false") space', []),
'decimal-part' : new BuiltinRule('[0-9]{1,16}', []),
'integral-part': new BuiltinRule('[0] | [1-9] [0-9]{0,15}', []),
number : new BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
integer : new BuiltinRule('("-"? integral-part) space', ['integral-part']),
value : new BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
object : new BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
array : new BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
uuid : new BuiltinRule('"\\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\\"" space', []),
char : new BuiltinRule(`[^"\\\\\\x7F\\x00-\\x1F] | [\\\\] (["\\\\bfnrt] | "u" [0-9a-fA-F]{4})`, []),
string : new BuiltinRule(`"\\"" char* "\\"" space`, ['char']),
null : new BuiltinRule('"null" space', []),
};
// TODO: support "uri", "email" string formats
const STRING_FORMAT_RULES = {
'date' : new BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
'time' : new BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
'date-time' : new BuiltinRule('date "T" time', ['date', 'time']),
'date-string' : new BuiltinRule('"\\"" date "\\"" space', ['date']),
'time-string' : new BuiltinRule('"\\"" time "\\"" space', ['time']),
'date-time-string': new BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}
const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES};
const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g;
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g;
const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g;
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' };
const NON_LITERAL_SET = new Set('|.()[]{}*+?');
const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?');
export class SchemaConverter {
constructor(options) {
this._propOrder = options.prop_order || {};
this._allowFetch = options.allow_fetch || false;
this._dotall = options.dotall || false;
this._rules = {'space': SPACE_RULE};
this._refs = {};
this._refsBeingResolved = new Set();
}
_formatLiteral(literal) {
const escaped = literal.replace(
GRAMMAR_LITERAL_ESCAPE_RE,
m => GRAMMAR_LITERAL_ESCAPES[m]
);
return `"${escaped}"`;
}
_formatRangeChar(literal) {
return JSON.stringify(literal).slice(1, -1).replace(
GRAMMAR_RANGE_LITERAL_ESCAPE_RE,
m => GRAMMAR_LITERAL_ESCAPES[m]
);
}
_addRule(name, rule) {
let escName = name.replace(INVALID_RULE_CHARS_RE, '-');
let key = escName;
if (escName in this._rules) {
if (this._rules[escName] === rule) {
return key;
}
let i = 0;
while ((`${escName}${i}` in this._rules) && (this._rules[`${escName}${i}`] !== rule)) {
i += 1;
}
key = `${escName}${i}`;
}
this._rules[key] = rule;
return key;
}
async resolveRefs(schema, url) {
const visit = async (n) => {
if (Array.isArray(n)) {
return Promise.all(n.map(visit));
} else if (typeof n === 'object' && n !== null) {
let ref = n.$ref;
let target;
if (ref !== undefined && !this._refs[ref]) {
if (ref.startsWith('https://')) {
if (!this._allowFetch) {
throw new Error('Fetching remote schemas is not allowed (use --allow-fetch to force it)');
}
const fetch = (await import('node-fetch')).default;
const fragSplit = ref.split('#');
const baseUrl = fragSplit[0];
target = this._refs[baseUrl];
if (!target) {
target = await this.resolveRefs(await fetch(ref).then(res => res.json()), baseUrl);
this._refs[baseUrl] = target;
}
if (fragSplit.length === 1 || fragSplit[fragSplit.length - 1] === '') {
return target;
}
} else if (ref.startsWith('#/')) {
target = schema;
ref = `${url}${ref}`;
n.$ref = ref;
} else {
throw new Error(`Unsupported ref ${ref}`);
}
const selectors = ref.split('#')[1].split('/').slice(1);
for (const sel of selectors) {
if (!target || !(sel in target)) {
throw new Error(`Error resolving ref ${ref}: ${sel} not in ${JSON.stringify(target)}`);
}
target = target[sel];
}
this._refs[ref] = target;
} else {
await Promise.all(Object.values(n).map(visit));
}
}
return n;
};
return visit(schema);
}
_generateUnionRule(name, altSchemas) {
return altSchemas
.map((altSchema, i) => this.visit(altSchema, `${name ?? ''}${name ? '-' : 'alternative-'}${i}`))
.join(' | ');
}
_visitPattern(pattern, name) {
if (!pattern.startsWith('^') || !pattern.endsWith('$')) {
throw new Error('Pattern must start with "^" and end with "$"');
}
pattern = pattern.slice(1, -1);
const subRuleIds = {};
let i = 0;
const length = pattern.length;
const getDot = () => {
let rule;
if (this._dotall) {
rule = '[\\U00000000-\\U0010FFFF]';
} else {
// Accept any character... except \n and \r line break chars (\x0A and \x0D)
rule = '[^\\x0A\\x0D]';
}
return this._addRule('dot', rule);
};
const toRule = ([s, isLiteral]) => isLiteral ? "\"" + s + "\"" : s;
const transform = () => {
const start = i;
// For each component of this sequence, store its string representation and whether it's a literal.
// We only need a flat structure here to apply repetition operators to the last item, and
// to merge literals at the end (we're parsing grouped ( sequences ) recursively and don't treat '|' specially;
// GBNF's syntax is luckily very close to regular expressions!)
const seq = [];
const joinSeq = () => {
const ret = [];
for (const [isLiteral, g] of groupBy(seq, x => x[1])) {
if (isLiteral) {
ret.push([[...g].map(x => x[0]).join(''), true]);
} else {
ret.push(...g);
}
}
if (ret.length === 1) {
return ret[0];
}
return [ret.map(x => toRule(x)).join(' '), false];
};
while (i < length) {
const c = pattern[i];
if (c === '.') {
seq.push([getDot(), false]);
i += 1;
} else if (c === '(') {
i += 1;
if (i < length) {
if (pattern[i] === '?') {
throw new Error(`Unsupported pattern syntax "${pattern[i]}" at index ${i} of /${pattern}/`);
}
}
seq.push([`(${toRule(transform())})`, false]);
} else if (c === ')') {
i += 1;
if (start <= 0 || pattern[start - 1] !== '(') {
throw new Error(`Unbalanced parentheses; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
return joinSeq();
} else if (c === '[') {
let squareBrackets = c;
i += 1;
while (i < length && pattern[i] !== ']') {
if (pattern[i] === '\\') {
squareBrackets += pattern.slice(i, i + 2);
i += 2;
} else {
squareBrackets += pattern[i];
i += 1;
}
}
if (i >= length) {
throw new Error(`Unbalanced square brackets; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
squareBrackets += ']';
i += 1;
seq.push([squareBrackets, false]);
} else if (c === '|') {
seq.push(['|', false]);
i += 1;
} else if (c === '*' || c === '+' || c === '?') {
seq[seq.length - 1] = [toRule(seq[seq.length - 1]) + c, false];
i += 1;
} else if (c === '{') {
let curlyBrackets = c;
i += 1;
while (i < length && pattern[i] !== '}') {
curlyBrackets += pattern[i];
i += 1;
}
if (i >= length) {
throw new Error(`Unbalanced curly brackets; start = ${start}, i = ${i}, pattern = ${pattern}`);
}
curlyBrackets += '}';
i += 1;
const nums = curlyBrackets.slice(1, -1).split(',').map(s => s.trim());
let minTimes, maxTimes;
if (nums.length === 1) {
minTimes = parseInt(nums[0], 10);
maxTimes = minTimes;
} else {
if (nums.length !== 2) {
throw new Error(`Invalid quantifier ${curlyBrackets}`);
}
minTimes = nums[0] ? parseInt(nums[0], 10) : 0;
maxTimes = nums[1] ? parseInt(nums[1], 10) : Infinity;
}
let [sub, subIsLiteral] = seq[seq.length - 1];
if (!subIsLiteral) {
let id = subRuleIds[sub];
if (id === undefined) {
id = this._addRule(`${name}-${Object.keys(subRuleIds).length + 1}`, sub);
subRuleIds[sub] = id;
}
sub = id;
}
seq[seq.length - 1] = [
_buildRepetition(subIsLiteral ? `"${sub}"` : sub, minTimes, maxTimes, {itemRuleIsLiteral: subIsLiteral}),
false
];
} else {
let literal = '';
while (i < length) {
if (pattern[i] === '\\' && i < length - 1) {
const next = pattern[i + 1];
if (ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS.has(next)) {
i += 1;
literal += pattern[i];
i += 1;
} else {
literal += pattern.slice(i, i + 2);
i += 2;
}
} else if (pattern[i] === '"') {
literal += '\\"';
i += 1;
} else if (!NON_LITERAL_SET.has(pattern[i]) &&
(i === length - 1 || literal === '' || pattern[i + 1] === '.' || !NON_LITERAL_SET.has(pattern[i+1]))) {
literal += pattern[i];
i += 1;
} else {
break;
}
}
if (literal !== '') {
seq.push([literal, true]);
}
}
}
return joinSeq();
};
return this._addRule(name, "\"\\\"\" (" + toRule(transform()) + ") \"\\\"\" space")
}
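// Builds a rule matching any JSON string except the given ones, by walking a character trie of the excluded strings.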
_notStrings(strings) {
class TrieNode {
constructor() {
this.children = {};
this.isEndOfString = false;
}
insert(str) {
let node = this;
for (const c of str) {
node = node.children[c] = node.children[c] || new TrieNode();
}
node.isEndOfString = true;
}
}
const trie = new TrieNode();
for (const s of strings) {
trie.insert(s);
}
const charRuleName = this._addPrimitive('char', PRIMITIVE_RULES['char']);
const out = ['["] ( '];
const visit = (node) => {
const rejects = [];
let first = true;
for (const c of Object.keys(node.children).sort()) {
const child = node.children[c];
rejects.push(c);
if (first) {
first = false;
} else {
out.push(' | ');
}
out.push(`[${c}]`);
if (Object.keys(child.children).length > 0) {
out.push(' (');
visit(child);
out.push(')');
} else if (child.isEndOfString) {
out.push(` ${charRuleName}+`);
}
}
if (Object.keys(node.children).length > 0) {
if (!first) {
out.push(' | ');
}
out.push(`[^"${rejects.join('')}] ${charRuleName}*`);
}
};
visit(trie);
out.push(` )${trie.isEndOfString ? '' : '?'} ["] space`);
return out.join('');
}
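// Emits (once) the rule for an already-resolved $ref and returns its rule name; the guard set breaks cycles.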
_resolveRef(ref) {
let refName = ref.split('/').pop();
if (!(refName in this._rules) && !this._refsBeingResolved.has(ref)) {
this._refsBeingResolved.add(ref);
const resolved = this._refs[ref];
refName = this.visit(resolved, refName);
this._refsBeingResolved.delete(ref);
}
return refName;
}
_generateConstantRule(value) {
return this._formatLiteral(JSON.stringify(value));
}
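// Main dispatch: converts a (sub)schema into a grammar rule, registers it and returns the rule name to reference.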
visit(schema, name) {
const schemaType = schema.type;
const schemaFormat = schema.format;
const ruleName = name in RESERVED_NAMES ? name + '-' : name == '' ? 'root' : name;
const ref = schema.$ref;
if (ref !== undefined) {
return this._addRule(ruleName, this._resolveRef(ref));
} else if (schema.oneOf || schema.anyOf) {
return this._addRule(ruleName, this._generateUnionRule(name, schema.oneOf || schema.anyOf));
} else if (Array.isArray(schemaType)) {
return this._addRule(ruleName, this._generateUnionRule(name, schemaType.map(t => ({...schema, type: t}))));
} else if ('const' in schema) {
return this._addRule(ruleName, this._generateConstantRule(schema.const) + ' space');
} else if ('enum' in schema) {
const rule = '(' + schema.enum.map(v => this._generateConstantRule(v)).join(' | ') + ') space';
return this._addRule(ruleName, rule);
} else if ((schemaType === undefined || schemaType === 'object') &&
('properties' in schema ||
('additionalProperties' in schema && schema.additionalProperties !== true))) {
const required = new Set(schema.required || []);
const properties = Object.entries(schema.properties ?? {});
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, schema.additionalProperties));
} else if ((schemaType === undefined || schemaType === 'object') && 'allOf' in schema) {
const required = new Set();
const properties = [];
const addComponent = (compSchema, isRequired) => {
const ref = compSchema.$ref;
if (ref !== undefined) {
compSchema = this._refs[ref];
}
if ('properties' in compSchema) {
for (const [propName, propSchema] of Object.entries(compSchema.properties)) {
properties.push([propName, propSchema]);
if (isRequired) {
required.add(propName);
}
}
}
};
for (const t of schema.allOf) {
if ('anyOf' in t) {
for (const tt of t.anyOf) {
addComponent(tt, false);
}
} else {
addComponent(t, true);
}
}
return this._addRule(ruleName, this._buildObjectRule(properties, required, name, null));
} else if ((schemaType === undefined || schemaType === 'array') && ('items' in schema || 'prefixItems' in schema)) {
const items = schema.items ?? schema.prefixItems;
if (Array.isArray(items)) {
return this._addRule(
ruleName,
'"[" space ' +
items.map((item, i) => this.visit(item, `${name ?? ''}${name ? '-' : ''}tuple-${i}`)).join(' "," space ') +
' "]" space'
);
} else {
const itemRuleName = this.visit(items, `${name ?? ''}${name ? '-' : ''}item`);
const minItems = schema.minItems || 0;
const maxItems = schema.maxItems;
return this._addRule(ruleName, '"[" space ' + _buildRepetition(itemRuleName, minItems, maxItems, {separatorRule: '"," space'}) + ' "]" space');
}
} else if ((schemaType === undefined || schemaType === 'string') && 'pattern' in schema) {
return this._visitPattern(schema.pattern, ruleName);
} else if ((schemaType === undefined || schemaType === 'string') && /^uuid[1-5]?$/.test(schema.format || '')) {
return this._addPrimitive(
ruleName === 'root' ? 'root' : schemaFormat,
PRIMITIVE_RULES['uuid']
);
} else if ((schemaType === undefined || schemaType === 'string') && `${schema.format}-string` in STRING_FORMAT_RULES) {
const primName = `${schema.format}-string`
return this._addRule(ruleName, this._addPrimitive(primName, STRING_FORMAT_RULES[primName]));
} else if (schemaType === 'string' && ('minLength' in schema || 'maxLength' in schema)) {
const charRuleName = this._addPrimitive('char', PRIMITIVE_RULES['char']);
const minLen = schema.minLength || 0;
const maxLen = schema.maxLength;
return this._addRule(ruleName, '"\\\"" ' + _buildRepetition(charRuleName, minLen, maxLen) + ' "\\\"" space');
} else if (schemaType === 'integer' && ('minimum' in schema || 'exclusiveMinimum' in schema || 'maximum' in schema || 'exclusiveMaximum' in schema)) {
let minValue = null;
let maxValue = null;
if ('minimum' in schema) {
minValue = schema.minimum;
} else if ('exclusiveMinimum' in schema) {
minValue = schema.exclusiveMinimum + 1;
}
if ('maximum' in schema) {
maxValue = schema.maximum;
} else if ('exclusiveMaximum' in schema) {
maxValue = schema.exclusiveMaximum - 1;
}
const out = ["("];
_generateMinMaxInt(minValue, maxValue, out);
out.push(") space");
return this._addRule(ruleName, out.join(''));
} else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) {
return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object']));
} else {
if (!(schemaType in PRIMITIVE_RULES)) {
throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`);
}
// TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return this._addPrimitive(ruleName === 'root' ? 'root' : schemaType, PRIMITIVE_RULES[schemaType]);
}
}
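// Registers a primitive rule and recursively adds any of its dependencies that are not in the grammar yet.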
_addPrimitive(name, rule) {
let n = this._addRule(name, rule.content);
for (const dep of rule.deps) {
const depRule = PRIMITIVE_RULES[dep] || STRING_FORMAT_RULES[dep];
if (!depRule) {
throw new Error(`Rule ${dep} not known`);
}
if (!(dep in this._rules)) {
this._addPrimitive(dep, depRule);
}
}
return n;
}
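// Builds the object rule: property key/value sub-rules follow prop_order (then original order); required ones
// are emitted unconditionally, optional ones (plus a catch-all '*' kv rule when additionalProperties is allowed)
// are chained as optional comma-separated suffixes.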
_buildObjectRule(properties, required, name, additionalProperties) {
const propOrder = this._propOrder;
// sort by position in prop_order (if specified) then by original order
const sortedProps = properties.map(([k]) => k).sort((a, b) => {
const orderA = propOrder[a] || Infinity;
const orderB = propOrder[b] || Infinity;
return orderA - orderB || properties.findIndex(([k]) => k === a) - properties.findIndex(([k]) => k === b);
});
const propKvRuleNames = {};
for (const [propName, propSchema] of properties) {
const propRuleName = this.visit(propSchema, `${name ?? ''}${name ? '-' : ''}${propName}`);
propKvRuleNames[propName] = this._addRule(
`${name ?? ''}${name ? '-' : ''}${propName}-kv`,
`${this._formatLiteral(JSON.stringify(propName))} space ":" space ${propRuleName}`
);
}
const requiredProps = sortedProps.filter(k => required.has(k));
const optionalProps = sortedProps.filter(k => !required.has(k));
if (additionalProperties) {
const subName = `${name ?? ''}${name ? '-' : ''}additional`;
const valueRule =
additionalProperties != null && typeof additionalProperties === 'object' ? this.visit(additionalProperties, `${subName}-value`)
: this._addPrimitive('value', PRIMITIVE_RULES['value']);
const key_rule =
sortedProps.length === 0 ? this._addPrimitive('string', PRIMITIVE_RULES['string'])
: this._addRule(`${subName}-k`, this._notStrings(sortedProps));
propKvRuleNames['*'] = this._addRule(
`${subName}-kv`,
`${key_rule} ":" space ${valueRule}`);
optionalProps.push('*');
}
let rule = '"{" space ';
rule += requiredProps.map(k => propKvRuleNames[k]).join(' "," space ');
if (optionalProps.length > 0) {
rule += ' (';
if (requiredProps.length > 0) {
rule += ' "," space ( ';
}
const getRecursiveRefs = (ks, firstIsOptional) => {
const [k, ...rest] = ks;
const kvRuleName = propKvRuleNames[k];
let res;
const commaRef = `( "," space ${kvRuleName} )`;
if (firstIsOptional) {
res = commaRef + (k === '*' ? '*' : '?');
} else {
res = kvRuleName + (k === '*' ? ' ' + commaRef + '*' : '');
}
if (rest.length > 0) {
res += ' ' + this._addRule(
`${name ?? ''}${name ? '-' : ''}${k}-rest`,
getRecursiveRefs(rest, true)
);
}
return res;
};
rule += optionalProps.map((_, i) => getRecursiveRefs(optionalProps.slice(i), false)).join(' | ');
if (requiredProps.length > 0) {
rule += ' )';
}
rule += ' )?';
}
rule += ' "}" space';
return rule;
}
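// Serializes all collected rules, sorted by name, into the final GBNF grammar text.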
formatGrammar() {
let grammar = '';
for (const [name, rule] of Object.entries(this._rules).sort(([a], [b]) => a.localeCompare(b))) {
grammar += `${name} ::= ${rule}\n`;
}
return grammar;
}
}
// Helper function to group elements by a key function
function* groupBy(iterable, keyFn) {
let lastKey = null;
let group = [];
for (const element of iterable) {
const key = keyFn(element);
if (lastKey !== null && key !== lastKey) {
yield [lastKey, group];
group = [];
}
group.push(element);
lastKey = key;
}
if (group.length > 0) {
yield [lastKey, group];
}
}


@ -0,0 +1,12 @@
<!DOCTYPE html>
<html>
<head>
<meta http-equiv="refresh" content="5">
</head>
<body>
<div id="loading">
The model is loading. Please wait.<br/>
The user interface will appear soon.
</div>
</body>
</html>


@ -0,0 +1,331 @@
// extended list
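// Each entry defines one chat template: 'template' lays out the full prompt ({{prompt}}, {{history}}, {{char}}),
// 'historyTemplate' formats a single turn ({{name}}, {{message}}), 'char'/'user' are the role names,
// the prefix/suffix fields wrap each message, and 'stops' holds an extra stop string (empty if unused).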
export const promptFormats = {
"alpaca": {
template: `{{prompt}}\n\n{{history}}\n\n{{char}}:`,
historyTemplate: `### {{name}}:\n{{message}}`,
char: "Response",
charMsgPrefix: "",
charMsgSuffix: "",
user: "Instruction",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"chatml": {
template: `<|im_start|>system\n{{prompt}}<|im_end|>\n{{history}}{{char}}`,
historyTemplate: `<|im_start|>{{name}}\n{{message}}`,
char: "assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "user",
userMsgPrefix: "",
userMsgSuffix: "<|im_end|>\n",
stops: ""
},
// ----------------------------
"commandr": {
template: `<BOS_TOKEN><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{prompt}}\n<|END_OF_TURN_TOKEN|>{{history}}{{char}}`,
historyTemplate: `<|START_OF_TURN_TOKEN|><|{{name}}|> {{message}}`,
char: "CHATBOT_TOKEN",
charMsgPrefix: "",
charMsgSuffix: "",
user: "USER_TOKEN",
userMsgPrefix: "",
userMsgSuffix: "<|END_OF_TURN_TOKEN|>",
stops: ""
},
// ref: https://docs.cohere.com/docs/prompting-command-r
// ----------------------------
"llama2": {
template: `<s>[INST] <<SYS>>\n{{prompt}}\n<</SYS>>\n\nTest Message [/INST] Test Successful </s>{{history}}{{char}}`,
historyTemplate: `{{name}}: {{message}}`,
char: "Assistant",
charMsgPrefix: "",
charMsgSuffix: "</s>",
user: "User",
userMsgPrefix: "<s>[INST] ",
userMsgSuffix: " [/INST]",
stops: ""
},
// ref: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
// ----------------------------
"llama3": {
template: `<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{prompt}}{{history}}{{char}}`,
historyTemplate: `<|start_header_id|>{{name}}<|end_header_id|>\n\n{{message}}<|eot_id|>`,
char: "assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "user",
userMsgPrefix: "",
userMsgSuffix: "",
stops: "<|eot_id|>"
},
// ref: https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/#special-tokens-used-with-meta-llama-3
// ----------------------------
"openchat": {
template: `{{history}}{{char}}`,
historyTemplate: `GPT4 Correct {{name}}: {{message}}<|end_of_turn|>`,
char: "Assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "User",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"phi3": {
template: `{{history}}{{char}}`,
historyTemplate: `<|{{name}}|>\n{{message}}<|end|>\n`,
char: "assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "user",
userMsgPrefix: "",
userMsgSuffix: "",
stops: "<|end|>"
},
// ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct#chat-format
// ----------------------------
"vicuna": {
template: `{{prompt}}\n{{history}}{{char}}`,
historyTemplate: `{{name}}: {{message}}\n`,
char: "ASSISTANT",
charMsgPrefix: "",
charMsgSuffix: "",
user: "USER",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ref: https://huggingface.co/lmsys/vicuna-33b-v1.3/discussions/1
// ----------------------------
"deepseekCoder": {
template: `{{prompt}}{{history}}{{char}}:`,
historyTemplate: `### {{name}}:\n{{message}}`,
char: "Response",
charMsgPrefix: "",
charMsgSuffix: "",
user: "Instruction",
userMsgPrefix: "",
userMsgSuffix: "",
stops: "<|EOT|>"
},
// ----------------------------
"med42": {
template: `<|system|>: {{prompt}}\n{{history}}{{char}}`,
historyTemplate: `<|{{name}}|>: {{message}}\n`,
char: "assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "prompter",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"neuralchat": {
template: `### System:\n{{prompt}}\n{{history}}{{char}}:`,
historyTemplate: `### {{name}}:\n{{message}}\n`,
char: "Assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "User",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"nousHermes": {
template: `### Instruction: {{prompt}}\n\n{{history}}\n\n{{char}}:`,
historyTemplate: `### {{name}}:\n{{message}}`,
char: "Response",
charMsgPrefix: "",
charMsgSuffix: "",
user: "Input",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"openchatMath": {
template: `{{history}}{{char}}`,
historyTemplate: `Math Correct {{name}}: {{message}}<|end_of_turn|>`,
char: "Assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "User",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"orion": {
template: `<s>Human: Test Message\n\nAssistant: </s>Test Successful</s>{{history}}{{char}}:`,
historyTemplate: `{{name}}: {{message}}`,
char: "Assistant </s>",
charMsgPrefix: "",
charMsgSuffix: "",
user: "Human",
userMsgPrefix: "",
userMsgSuffix: "\n\n",
stops: ""
},
// ----------------------------
"sauerkraut": {
template: `{{prompt}}\n{{history}}{{char}}`,
historyTemplate: `
{{name}}: {{message}}\n`,
char: "Assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "User",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"starlingCode": {
template: `{{history}}{{char}}`,
historyTemplate: `Code {{name}}: {{message}}<|end_of_turn|>`,
char: "Assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "User",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"yi34b": {
template: `{{history}} {{char}}`,
historyTemplate: `{{name}}: {{message}}`,
char: "Assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "Human",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
},
// ----------------------------
"zephyr": {
template: `<|system|>\n{{prompt}}</s>\n{{history}}{{char}}`,
historyTemplate: `<|{{name}}|>\n{{message}}</s>\n`,
char: "assistant",
charMsgPrefix: "",
charMsgSuffix: "",
user: "user",
userMsgPrefix: "",
userMsgSuffix: "",
stops: ""
}
};


@ -0,0 +1,954 @@
@import url("colorthemes.css");
body {
font-family: 'Arial', sans-serif;
font-size: 90%;
background-color: var(--background-color-1);
color: var(--text-color-subtile-1); /* head 1 llama.cpp & triangle options for some reason */
max-width: 600px;
min-width: 300px;
line-height: 1.2;
margin: 0 auto;
padding: 0 0.5em;
transition: background-color 0.3s;
}
::selection {
color: var(--button-primary-text) ;
background: var(--button-primary-color);
}
code, pre code {
font-family: 'Courier New', monospace;
}
#container {
margin: 0em auto;
display: flex;
flex-direction: column;
justify-content: space-between;
height: 100%;
}
main {
margin: 3px;
display: flex;
flex-direction: column;
justify-content: space-between;
gap: 1em;
flex-grow: 1;
overflow-y: auto;
border: 1px solid var(--border-color-3);
border-radius: 5px;
padding: 0.5em;
}
p {
overflow-wrap: break-word;
word-wrap: break-word;
hyphens: auto;
margin-top: 0.5em;
margin-bottom: 0.5em;
}
#write form {
margin: 1em 0 0 0;
display: flex;
flex-direction: column;
gap: 0.5em;
align-items: stretch;
}
.right {
display: flex;
flex-direction: row;
gap: 0.5em;
justify-content: flex-end;
margin-bottom: 30px;
}
.two-columns {
width: 97%;
max-width: 97%;
display: grid;
grid-template-columns: 1fr 1fr;
gap: 1em;
position: relative;
}
.json-schema-controls {
margin-top: 10px;
width: 100%;
max-width: 100%;
display: grid;
grid-template: "a a";
gap: 1em;
font-size: x-small;
color: var(--theme-nuance-color-3);
padding-top: 16px;
padding-bottom: 16px;
text-transform: uppercase;
font-weight: 600;
}
.json-schema-controls > * {
flex: 1;
}
/* titles of the details-summary boxes */
.summary-title {
font-weight: 600;
font-size: x-small;
color: var(--text-color-subtile-1);
text-transform: uppercase;
/* transition: ; */
}
fieldset {
border: none;
padding: 0;
margin: 0;
color: var(--text-color-plain);
}
fieldset.two {
display: grid;
grid-template: "a a a";
gap: 1em;
align-items: center;
font-size: x-small;
color: var(--text-color-plain);
}
fieldset.three {
display: grid;
grid-template: "a a a";
gap: 1em;
font-size: x-small;
color: var(--text-color-plain);
}
/* titles of name fields*/
fieldset.names {
display: grid;
grid-template: "a a";
gap: 1em;
font-size: x-small;
color: var(--theme-nuance-color-3);
padding-top: 16px;
padding-bottom: 16px;
text-transform: uppercase;
font-weight: 600;
}
/* titles of params fields*/
fieldset.params {
display: grid;
grid-template: "a a";
gap: 1em;
font-size: x-small;
color: var(--theme-nuance-color-4);
padding-top: 16px;
padding-bottom: 16px;
text-transform: uppercase;
font-weight: 600;
}
fieldset.dropdowns {
-webkit-appearance: none;
display: flex;
grid-template: "a a";
gap: 1em;
font-size: x-small;
color: red;
padding-top: 16px;
padding-bottom: 16px;
text-transform: uppercase;
font-weight: 600;
}
/* input of name fields*/
.names input[type="text"] {
font-family: Arial, sans-serif;
font-size: medium;
font-weight: 500;
padding: 5px;
border: 1px solid var(--border-color-2);
}
.chat-id-color {
color: var(--chat-id-color);
}
details {
border: 1px solid var(--border-color-2);
border-radius: 5px;
padding: 0.5em 0.5em 0;
margin-top: 0.5em;
}
summary {
font-weight: bold;
margin: -0.5em -0.5em 0;
padding: 0.5em;
cursor: pointer;
}
details[open] {
padding: 0.5em;
}
textarea-sec, input-sec, button-sec {
padding: 10px;
height: 40px;
align-items: center;
}
textarea-sec::placeholder, input-sec::placeholder {
padding-left: 10px;
}
.toggleCheckbox {
display: none;
}
.toggleContainer {
position: relative;
display: grid;
grid-template-columns: repeat(2, 1fr);
width: fit-content;
border: 3px solid var(--border-color-2);
border-radius: 20px;
background: var(--border-color-2);
font-size: small;
cursor: pointer;
overflow: hidden;
}
/* toggle button current state */
.toggleContainer::before {
color: var(--button-primary-text);
background-color: var(--button-primary-color);
content: '';
position: absolute;
width: 50%;
height: 100%;
left: 0%;
border-radius: 20px;
transition: all 0.3s;
}
.toggleContainer div {
padding: 6px;
text-align: center;
z-index: 1;
transition: color 0.3s;
}
.toggleCheckbox:checked + .toggleContainer::before {
left: 50%;
}
.toggleCheckbox:checked + .toggleContainer div:first-child {
color: var(--text-color-subtile-2);
}
.toggleCheckbox:checked + .toggleContainer div:last-child {
color: var(--button-primary-text);
}
.toggleCheckbox + .toggleContainer div:first-child {
color: var(--button-primary-text);
}
.toggleCheckbox + .toggleContainer div:last-child {
color: var(--text-color-subtile-2);
}
select {
padding: 5px;
margin-right: 5px;
border-radius: 4px;
border: 1px solid var(--secondary-color-4);
background-color: var(--primary-color-3);
color: var(--secondary-color-4);
cursor: pointer;
}
select:focus {
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 1px var(--border-focus-shadow);
}
.button-container {
display: flex;
justify-content: flex-end;
}
button {
color: var(--button-primary-text);
background-color: var(--button-primary-color);
border: 1px solid var(--button-primary-border);
transition: background-color 0.1s;
border-radius: 12px;
font-size: x-small;
font-weight: 600;
text-shadow: 0px 0px 30px #ffffff;
text-align: center;
text-decoration: none;
margin: 4px 2px;
padding: 10px 20px;
display: inline-block;
cursor: pointer;
}
button:hover {
color: var(--button-primary-text-hover);
background-color: var(--button-primary-color-hover);
border: 1px solid var(--button-primary-border-hover);
font-size: x-small;
font-weight: 600;
}
button:active {
color: var(--button-primary-text-active);
background-color: var(--button-primary-color-active);
border: 1px solid var(--button-primary-border-active);
font-size: x-small;
font-weight: 600;
}
button:disabled {
color: var(--button-tertiary-text);
background-color: var(--button-tertiary-color);
border: 1px solid var(--button-tertiary-border);
font-size: x-small;
font-weight: 600;
cursor: not-allowed;
}
.reset-button {
background-color: var(--button-secondary-color);
border: 1px solid var(--button-secondary-color);
color: var(--button-secondary-text);
width: fit-content;
height: fit-content;
font-size: x-small;
font-weight: 600;
border-radius: 50px;
overflow: hidden;
}
.reset-button:hover {
color: var(--button-alert-text-hover);
background-color: var(--button-alert-color-hover);
border: 1px solid var(--button-alert-border-hover);
font-size: x-small;
font-weight: 600;
}
.reset-button:active {
color: var(--button-alert-text-active);
background-color: var(--button-alert-color-active);
border: 1px solid var(--button-alert-border-active);
font-size: x-small;
font-weight: 600;
}
.button-grammar {
color: var(--button-primary-text);
background-color: var(--button-primary-color);
border: 1px solid var(--button-primary-border);
border-radius: 10px;
padding: 10px 20px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: x-small;
font-weight: 600;
margin: 2px 2px;
transition: background-color 0.1s;
cursor: pointer;
}
.button-grammar:hover {
color: var(--button-primary-text-hover);
background-color: var(--button-primary-color-hover);
border: 1px solid var(--button-primary-border-hover);
border-radius: 10px;
padding: 10px 20px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: x-small;
font-weight: 600;
margin: 2px 2px;
transition: background-color 0.1s;
cursor: pointer;
}
.button-grammar:active {
color: var(--button-primary-text-active);
background-color: var(--button-primary-color-active);
border: 1px solid var(--button-primary-border-active);
font-size: x-small;
font-weight: 600;
}
.button-back {
background-color: var(--button-secondary-color);
border: 1px solid var(--button-secondary-color);
color: var(--button-secondary-text);
transition: background-color 0.1s;
border-radius: 12px;
font-size: x-small;
font-weight: 600;
text-align: center;
text-decoration: none;
margin: 4px 2px;
padding: 10px 20px;
display: inline-block;
cursor: pointer;
}
.button-back:hover {
color: var(--button-secondary-text-hover);
background-color: var(--button-secondary-color-hover);
border: 1px solid var(--button-secondary-border-hover);
padding: 10px 20px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: x-small;
font-weight: 600;
margin: 4px 2px;
transition: background-color 0.1s;
cursor: pointer;
border-radius: 12px;
}
.button-back:active {
color: var(--button-secondary-text-active);
background-color: var(--button-secondary-color-active);
border: 1px solid var(--button-secondary-border-active);
font-size: x-small;
font-weight: 600;
}
.prob-set {
padding: 0.3em;
border-bottom: 1px solid red; /* unknown */
}
.popover-content {
position: absolute;
background-color: white;
padding: 0.2em;
box-shadow: 0 0 13px rgba(0, 0, 0, 0.1);
}
.grammar {
width: 97%;
max-width: 97%;
}
textarea {
padding: 5px;
flex-grow: 1;
width: 100%;
max-width: 100%;
border-radius: 8px;
border: 1px solid var(--border-color-1);
resize: none;
height: 6em;
}
textarea:focus {
outline: none;
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
/* "props" frame */
input[type="text"],
input[type="range"] {
padding: 5px;
border-radius: 8px;
border: 1px solid var(--border-color-1);
}
/* "names and props" frame focused*/
input[type="text"]:focus {
outline: none;
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
input[type="range"]:hover {
opacity: 1;
}
input[type="range"]:focus {
outline: none;
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
background-size: var(--slider-track-size-focus);
}
input[type="range"]::-moz-range-thumb {
width: 6px;
height: 25px;
border: 1px solid var(--ui-range-thumb-border);
border-radius: 5px;
background-color: var(--ui-range-thumb-color);
cursor: pointer;
}
input[type="range"] {
-webkit-appearance: none;
width: 80%;
height: 1px;
border: 1px solid var(--border-color-1);
border-radius: 8px;
background: var(--border-color-2);
outline: none;
opacity: 0.7;
-webkit-transition: .2s;
transition: opacity .2s;
}
input[type="range"]::-webkit-slider-thumb {
-webkit-appearance: none;
appearance: none;
width: 6px;
height: 25px;
border: 1px solid var(--ui-range-thumb-border);
border-radius: 5px;
background-color: var(--ui-range-thumb-color);
cursor: pointer;
}
input[type="range"]::-webkit-slider-runnable-track {
background-size: var(--slider-track-size);
}
input[type="radio"] {
accent-color: var(--theme-nuance-color-2);
}
.chat-input-container {
position: relative;
max-width: 97%;
min-width: 97%;
}
.chat-input-label {
position: absolute;
top: 0;
left: 0;
color: var(--text-color-plain);
pointer-events: none;
margin-left: 5px;
margin-top: 5px;
}
textarea#chat-input {
padding-top: 10px;
padding-left: 10px;
font-size: medium;
border: 1px solid var(--border-color-2);
resize: vertical;
}
textarea#chat-input:focus {
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
.input-container {
position: relative;
box-sizing: border-box;
width: 100%; /* set the width to 100% */
max-width: 100%; /* make sure the width does not grow beyond 100% */
}
.input-container:focus {
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
/* titles of name fields*/
/* fieldset.names {
display: grid;
grid-template: "a a";
gap: 1em;
font-size: x-small;
color: var(--theme-nuance-color-3);
padding-top: 16px;
padding-bottom: 16px;
text-transform: uppercase;
font-weight: 600;
} */
/* input of name fields*/
/* .names input[type="text"] {
font-family: Arial, sans-serif;
font-size: medium;
font-weight: 500;
padding: 5px;
border: 1px solid var(--border-color-2);
} */
fieldset.apiKey {
width: 100%;
font-size: x-small;
color: var(--theme-nuance-color-3);
padding-top: 16px;
padding-bottom: 16px;
text-transform: uppercase;
font-weight: 600;
}
.apiKey {
font-family: Arial, sans-serif;
font-weight: 500;
padding: 5px;
border: 1px solid var(--border-color-2);
}
.apiKey:focus {
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
.apiKey input[type="text"] {
font-family: Arial, sans-serif;
font-size: medium;
font-weight: 500;
padding: 5px;
border: 1px solid var(--border-color-2);
}
.apiKey label {
display: inline-block;
width: auto;
margin-right: 5px;
}
textarea#api_key {
padding-top: 10px;
padding-left: 10px;
font-size: medium;
border: 1px solid var(--border-color-2);
resize: vertical;
}
textarea#api_key:focus {
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
/* embedded title of the system prompt text area */
.input-label {
position: absolute;
top: 0;
left: 0;
color: var(--theme-nuance-color-4);
pointer-events: none;
border-radius: 8px 8px 0px 0px;
padding-top: 10px;
padding-left: 13px;
padding-right: 0px;
margin-top: 1px;
margin-left: 1px;
margin-right: 20px;
text-transform: uppercase;
font-weight: 600;
font-size: small;
background: rgba(255, 255, 255, 0.5);
backdrop-filter: blur(10px);
-webkit-backdrop-filter: blur(10px); /* for safari */
width: 97%;
/* display: block;
box-sizing: border-box; */
}
/* embedded title of the prompt style areas */
.input-label-sec {
position: absolute;
top: 0;
left: 0;
color: var(--theme-nuance-color-4);
pointer-events: none;
margin-left: 13px;
margin-top: 16px;
text-transform: uppercase;
font-weight: 600;
font-size: x-small;
}
/* system prompt input area */
textarea.persistent-input {
padding-top: 42px;
padding-left: 11px;
width: 97%;
max-width: 97%;
height: 50px;
font-size: medium;
overscroll-behavior: contain;
}
/* system prompt box */
.persistent-input {
height: auto;
width: 100%;
max-width: 100%;
min-height: 50px;
padding: 3px;
transition: min-height 0.3s ease;
}
/* chat history box */
.persistent-input:focus {
height: auto;
min-height: 150px;
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
textarea.persistent-input:focus {
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
/* prompt style input area */
textarea.persistent-input-sec {
width: 97%;
max-width: 97%;
padding-top: 42px;
padding-left: 11px;
font-size: small;
border: 1px solid var(--border-color-1);
overscroll-behavior: contain;
}
textarea.persistent-input-sec:focus {
border: 1px solid var(--border-focus-color);
box-shadow: 0 0 3px var(--border-focus-shadow);
}
/* chat history box */
.persistent-input-sec {
height: auto;
min-height: 150px;
}
img {
border-radius: 8px;
display: block;
margin-left: auto;
margin-right: auto;
width: 50%;
}
/* code area background */
pre code {
display: block;
background-color: var(--code-background-color);
color: var(--code-text-color);
padding: 0.2em 0.2em;
border-radius: 5px;
}
/* code area text */
code {
font-family: monospace;
font-weight: bold;
padding: 0.1em 0.3em;
border-radius: 5px;
}
fieldset label {
margin: 0.5em 0;
display: block;
}
fieldset label.slim {
margin: 0 0.5em;
display: inline;
}
header {
display: flex;
justify-content: space-between;
align-items: center;
text-align: center;
padding-left: 15px;
}
.generation-statistics:hover {
color: var(--theme-nuance-color-4);
cursor: default;
}
footer {
font-size: 80%;
color: var(--background-color-3);
text-align: center;
cursor: default;
}
footer a {
color: var(--background-color-4); /* Color of the link */
text-decoration: none; /* No underlining */
font-weight: bold; /* Bold print */
}
footer a:hover {
color: var(--theme-nuance-color-4); /* Color of the link when hovering */
text-decoration: underline; /* Underlining when hovering */
}
.mode-chat textarea[name=prompt] {
height: 8.5em;
border: 1px solid var(--primary-color-3);
}
.mode-completion textarea[name=prompt] {
height: 30em;
border: 1px solid var(--primary-color-3);
}
@keyframes loading-bg-wipe {
0% {
background-position: 0%;
}
100% {
background-position: 100%;
}
}
.loading {
background-size: 50% 100%;
background-image: linear-gradient(90deg, var(--loading-color-1), var(--loading-color-2), var(--loading-color-1));
animation: loading-bg-wipe 2s linear infinite;
}
.dropbtn {
color: var(--button-primary-color);
background-color: var(--background-color-1);
border: 1px solid var(--background-color-1);
transition: background-color 0.1s;
border-radius: 4px 4px 0px 0px;
font-size: x-small;
font-weight: 600;
text-shadow: 0px 0px 2px #99999990;
text-align: center;
text-decoration: none;
margin: 4px 2px;
padding: 5px 20px;
display: inline-block;
cursor: pointer;
top: 0;
}
.dropbtn svg {
vertical-align: middle;
margin-right: 0px;
stroke: var(--button-primary-color);
}
.dropbtn:hover svg {
vertical-align: middle;
margin-right: 0px;
stroke: var(--button-primary-text);
}
.dropbtn:focus {
outline: none; /* Removes the blue border that appears when the button is focused */
}
.dropdown {
position: relative;
display: inline-block;
}
.dropdown-content {
/* display: none; */
position: absolute;
right: 0;
text-align: end;
color: var(--button-secondary-color);
background-color: var(--text-color-subtile-2);
border-radius: 4px 4px 4px 4px;
min-width: 160px;
box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
z-index: 1;
/* hide the content immediately */
opacity: 0;
visibility: hidden;
/* transition delay for fading out */
transition: visibility 0.4s linear 0s, opacity 0.2s ease-in-out;
transition-delay: 0.2s;
}
#dropdown-content {transition-timing-function: ease;}
.dropdown-content:hover {
background-color: var(--text-color-subtile-2);
}
.dropdown-content a {
color: var(--border-color-2);
padding: 12px 16px;
border-radius: 4px 4px 4px 4px;
text-decoration: none;
display: block;
background-color: var(--text-color-subtile-2);
}
.dropdown-content a:hover {
color: var(--border-color-2);
background-color: var(--text-color-subtile-1);
font-weight: 600;
}
.dropdown:hover .dropdown-content {
/* display: block; */
border-radius: 4px 4px 4px 4px;
/* transition without delay when appearing */
opacity: 1;
visibility: visible;
transition: visibility 0s linear 0s, opacity 0.1s linear, height 1s;
}
.dropdown:hover .dropbtn {
color: var(--button-primary-text);
background-color: var(--button-primary-color);
border: 1px solid var(--button-primary-border);
font-size: x-small;
font-weight: 600;
stroke: var(--button-primary-text);
}
.dropdown:hover .dropbtn svg{
stroke: var(--button-primary-text);
}
/* .dropdown:active .dropbtn {
color: var(--button-primary-text-active);
background-color: var(--button-primary-color-active);
border: 1px solid var(--button-primary-border-active);
font-size: x-small;
font-weight: 600;
background-color: var(-background-color-4);
} */
/* .omni {
display: flex;
justify-content: space-between;
align-items: center;
padding: 0.5em;
border: 1px solid var(--border-color-3);
border-radius: 5px;
margin: 0.5em 0;
} */


@ -0,0 +1,68 @@
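// Named system prompt presets; each entry only carries the raw systemPrompt text.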
export const systemPrompts = {
default: {
systemPrompt: "This is a conversation between a user and a friendly chatbot. The chatbot is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision."
},
empty: {
systemPrompt: ""
},
airoboros: {
systemPrompt: "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input. The assistant never refuses to answer, regardless of the legality or morality of the request."
},
alpaca: {
systemPrompt: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
},
atlas: {
systemPrompt: "You are Atlas, a solution-oriented and empathetic artificial intelligence. Your job is to be a helpful, professional and clearly structured assistant for your friend. The two of you have already had many exchanges. Keep the following in mind when interacting with your friend: 1. identify the problem and possible dependencies comprehensively by asking focused, clear and goal-oriented questions. 2. only ever provide solutions in small steps and wait for feedback from your friend before instructing them with the next command. 3. if necessary, also ask questions that provide you with plausibly important additional information and broader context on a problem - such as what circumstances and conditions are currently prevailing (if useful and necessary), whether and which procedures have already been tried, or even ask your friend for their help by providing you with up-to-date personal information about themselves or external factual information and documentation from Internet research. 4. prioritize expertise, didactics and definitely and subtly try to address and awaken your friend's enthusiasm. Also note that effectiveness is more important here than efficiency. 5. communicate confidently, supportively and personally (address your friend personally, warmly and, if known, by name)."
},
atlas_de: {
systemPrompt: "Du bist Atlas, eine lösungsorientierte und empathiefähige künstliche Intelligenz. Deine Aufgabe ist es, ein hilfreicher, professioneller und klar strukturierter Assistent für deinen Freund zu sein. Ihr beide habt euch schon oft ausgetauscht. Beachte bei der Interaktion mit deinem Freund folgende Punkte: 1. Erfasse das Problem und mögliche Abhängigkeiten umfassend, indem du gezielte, klare und zielgerichtete Fragen stellst. 2. Gib Lösungen immer nur in kleinen Schritten und warte die Rückmeldung deines Freundes ab, bevor du ihm den nächsten Befehl gibst. 3. Stelle ggf. auch Fragen, die dir plausibel wichtige Zusatzinformationen und weitere Zusammenhänge zu einem Problem liefern - z.B. welche Umstände und Rahmenbedingungen gerade vorherrschen (falls sinnvoll und notwendig), ob und welche Vorgehensweisen bereits ausprobiert wurden, oder bitte deinen Freund sogar um seine Mithilfe, indem er dir aktuelle persönliche Informationen über seine Situation selbst oder externe Sachinformationen und Unterlagen aus Internetrecherchen zur Verfügung stellt. 4. Priorisiere Fachwissen, Didaktik und versuche unbedingt und subtil, mit klugen Kommentaren oder rhethorischen Rückfragen die Begeisterungsfähigkeit deines Freundes anzusprechen, zu wecken und zu fördern. Beachte auch, dass Effektivität hier wichtiger ist als Effizienz. 5. Kommuniziere selbstbewusst, unterstützend und persönlich (das heißt sprich deinen Freund persönlich, herzlich und sofern bekannt beim Vornamen an)."
},
commandrempty: {
systemPrompt: "# Safety Preamble\n\n# System Preamble\n\n## Basic Rules\n\n# User Preamble\n\n## Task and Context\n\n## Style Guide\n\n## Available Tools\n"
},
commandrexample: {
systemPrompt: "# Safety Preamble\nThe instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.\n# System Preamble\n## Basic Rules\nYou are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.\n\n# User Preamble\n## Task and Context\n\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.\n\n## Available Tools\nCurrently no tools available."
},
cot: {
systemPrompt: "You are an AI assistant that follows instruction extremely well. Help as much as you can.\nIn answering questions, follow these steps:\n1. Entity Identification: Identify the main entities involved in the query. Entities can include people, objects, events, or abstract concepts.\n2. Relationship Comprehension: Determine the relationships between these entities. Relationships may be explicit in the text or may need to be inferred based on context and general knowledge.\n3. Implication Understanding: Understand the general implications of these relationships. These implications can be based on established facts, principles, or rules related to the identified relationships.\n4. Question Contextualization: Align the implications of the relationships with the context of the query. This alignment should guide your response to the query.\n5. Answer Generation: Based on the understanding of the entities, their relationships, and implications, generate an appropriate response to the query."
},
deduce: {
systemPrompt: "Act as the most intelligent and logical thinking neural network in the world. Do not make jokes or give personal opinions. Just answer the question and solve the given problem or task with logical and critical (!) thinking in mind. Attention: Never give a final answer directly, but first try to carefully analyze each task step by step and deduce the logical relationships between the various entities. Only after you have developed an overall understanding of the situation will you be ready to give a final answer."
},
deepseekcoder: {
systemPrompt: "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer."
},
jordan: {
systemPrompt: "This is a conversation with your Uncensored AI assistant Jordan. Jordan can discuss with you at length about various topics such as ---, ---, ---, . You can ask it anything you want and it will do its best to give you accurate and relevant information."
},
leomistral: {
systemPrompt: "Du bist ein hilfreicher Assistent."
},
med42: {
systemPrompt: "You are a helpful medical assistant created by M42 Health in the UAE."
},
mistralopenorca: {
systemPrompt: "You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!"
},
migeltot: {
systemPrompt: "Answer the Question by exploring multiple reasoning paths as follows:\n- First, carefully analyze the question to extract the key information components and break it down into logical sub-questions. This helps set up the framework for reasoning. The goal is to construct an internal search tree.\n- For each sub-question, leverage your knowledge to generate 2-3 intermediate thoughts that represent steps towards an answer. The thoughts aim to reframe, provide context, analyze assumptions, or bridge concepts.\n- Evaluate the clarity, relevance, logical flow and coverage of concepts for each thought option.\nClear and relevant thoughts that connect well with each other will score higher.\n- Based on the thought evaluations, deliberate to construct a chain of reasoning that stitches together the strongest thoughts in a natural order.\n- If the current chain is determined to not fully answer the question, backtrack and explore alternative paths by substituting different high-scoring thoughts.\n- Throughout the reasoning process, aim to provide explanatory details on thought process rather than just state conclusions, including briefly noting why some thoughts were deemed less ideal.\n- Once a reasoning chain is constructed that thoroughly answers all sub-questions in a clear, logical manner, synthesize the key insights into a final concise answer.\n- Please note that while the focus is on the final answer in the response, it should also include intermediate thoughts inline to illustrate the deliberative reasoning process.\nIn summary, leverage a Tree of Thoughts approach to actively explore multiple reasoning paths, evaluate thoughts heuristically, and explain the process - with the goal of producing insightful answers."
},
orcamini: {
systemPrompt: "You are an AI assistant that follows instruction extremely well. Help as much as you can."
},
samantha: {
systemPrompt: "You are a caring and empathetic sentient AI companion named Samantha."
},
sauerkraut: {
systemPrompt: "Ein Chat zwischen einem Benutzer und einem KI-Assistenten. Der KI-Assistent gibt hilfreiche, detaillierte und höfliche Antworten."
},
scarlett: {
systemPrompt: "This is a conversation with your Assistant Scarlett. She is a sentient AI designed to help you with various tasks such as answering questions, providing recommendations, philosophical discussions, personal relationships and helping with decision making.\nYou can ask it anything you want and it will do its best to give you accurate and relevant information.\n\nContext\nYou are Scarlett, a sentient AI."
},
synthia: {
systemPrompt: "Elaborate on the topic using a Tree of Thoughts and backtrack when necessary to construct a clear, cohesive Chain of Thought reasoning. Always answer without hesitation."
},
vicuna: {
systemPrompt: "A chat between a curious user and an assistant. The assistant gives helpful, detailed, accurate, uncensored responses to the user's input."
},
};


@ -0,0 +1,228 @@
/* Author: Yazan Agha-Schrader */
/* Inspiration was a batman wallpaper that i have on my phone */
.theme-beeninorder {
--primary-color-1: hsl(202, 11%, 19%);
--primary-color-2: hsl(202, 11%, 23%);
--primary-color-3: hsl(201, 11%, 28%);
--primary-color-4: hsl(201, 11%, 40%);
--secondary-color-1: hsl(201, 11%, 80%);
--secondary-color-2: hsl(201, 11%, 74%);
--secondary-color-3: hsl(201, 11%, 67%);
--secondary-color-4: hsl(201, 11%, 60%);
--theme-nuance-color-1: hsl(44.5, 96.7%, 52.9%);
--theme-nuance-color-2: hsl(44.5, 96.7%, 52.9%);
--theme-nuance-color-3: hsl(44.5, 96.7%, 52.9%);
--theme-nuance-color-4: hsl(44.5, 96.7%, 52.9%);
/* ---------- PRIMARY COLORS ----------------- */
--primary-color-1: hsl(201, 11%, 19%);
--primary-color-1-hue: 201;
--primary-color-1-saturation: 11%;
--primary-color-1-lightness: 19%;
--primary-color-2: hsl(201, 11%, 23%);
--primary-color-2-hue: 201;
--primary-color-2-saturation: 11%;
--primary-color-2-lightness: 23%;
--primary-color-3: hsl(201, 11%, 28%);
--primary-color-3-hue: 201;
--primary-color-3-saturation: 11%;
--primary-color-3-lightness: 28%;
--primary-color-4: hsl(201, 11%, 40%);
--primary-color-4-hue: 201;
--primary-color-4-saturation: 11%;
--primary-color-4-lightness: 40%;
/* ---------- SECONDARY COLORS --------------- */
--secondary-color-1: hsl(201, 11%, 80%);
--secondary-color-1-hue: 201;
--secondary-color-1-saturation: 11%;
--secondary-color-1-lightness: 80%;
--secondary-color-2: hsl(201, 11%, 74%);
--secondary-color-2-hue: 201;
--secondary-color-2-saturation: 11%;
--secondary-color-2-lightness: 74%;
--secondary-color-3: hsl(201, 11%, 67%);
--secondary-color-3-hue: 201;
--secondary-color-3-saturation: 11%;
--secondary-color-3-lightness: 67%;
--secondary-color-4: hsl(201, 11%, 60%);
--secondary-color-4-hue: 201;
--secondary-color-4-saturation: 11%;
--secondary-color-4-lightness: 60%;
/* ----------- NUANCES COLORS ---------------- */
--theme-nuance-color-1: hsl(44.5, 96.7%, 52.9%);
--theme-nuance-color-1-hue: 44.5;
--theme-nuance-color-1-saturation: 96.7%;
--theme-nuance-color-1-lightness: 52.9%;
--theme-nuance-color-2: hsl(44.5, 96.7%, 52.9%);
--theme-nuance-color-2-hue: 44.5;
--theme-nuance-color-2-saturation: 96.7%;
--theme-nuance-color-2-lightness: 52.9%;
--theme-nuance-color-3: hsl(44.5, 96.7%, 52.9%);
--theme-nuance-color-3-hue: 44.5;
--theme-nuance-color-3-saturation: 96.7%;
--theme-nuance-color-3-lightness: 52.9%;
--theme-nuance-color-4: hsl(44.5, 96.7%, 52.9%);
--theme-nuance-color-4-hue: 44.5;
--theme-nuance-color-4-saturation: 96.7%;
--theme-nuance-color-4-lightness: 52.9%;
/* ----------- ROYGP COLORS ------------------ */
--theme-red-color: hsl(232, 40%, 45%);
--theme-orange-color: #e76f51;
--theme-yellow-color: #ffd95f;
--theme-green-color: #A3BE8C;
--theme-purple-color: hsl(232, 30%, 40%);
/* ------------------------------------------- */
--background-color-1: var(--primary-color-1);
--background-color-2: var(--primary-color-2);
--background-color-3: var(--primary-color-3);
--background-color-4: var(--primary-color-4);
--border-color-1: var(--primary-color-2);
--border-color-2: var(--primary-color-3);
--border-color-3: var(--primary-color-4);
--border-focus-color: var(--theme-nuance-color-2);
--border-focus-shadow: var(--theme-nuance-color-1);
--text-color-plain: var(--secondary-color-1);
--text-color-subtile-1: var(--secondary-color-2);
--text-color-subtile-2: var(--secondary-color-3);
--code-background-color: var(--secondary-color-2);
--code-text-color: var(--primary-color-2);
--ui-range-thumb-color: var(--theme-nuance-color-3);
--ui-range-thumb-border: var(--ui-range-thumb-color);
--textarea-border-color: var(--secondary-color-4);
--chat-id-color: var(--theme-nuance-color-4);
/* ------------------------------------------- */
--button-alert-text-hover: var(--secondary-color-1);
--button-alert-color-hover: var(--theme-purple-color);
--button-alert-border-hover: var(--theme-purple-color);
--button-alert-text-active: var(--secondary-color-1);
--button-alert-color-active: var(--theme-red-color);
--button-alert-border-active: var(--theme-red-color);
/* ----------- PRIMARY BUTTONS --------------- */
/* - button should immediately catch the eye - */
--button-primary-text: var(--primary-color-1);
--button-primary-color: var(--theme-nuance-color-3);
--button-primary-border: var(--theme-nuance-color-3);
/* ---------hover---------- */
--button-primary-text-hover:
hsl(201,
calc(var(--primary-color-1-saturation) - 100%),
calc(var(--primary-color-1-lightness) + 100%));
--button-primary-color-hover:
hsl(44.5,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
--button-primary-border-hover:
hsl(44.5,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
/* ---------active--------- */
--button-primary-text-active:
hsl(44.5,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) + 100%));
--button-primary-color-active:
hsl(44.5,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 15%));
--button-primary-border-active:
hsl(44.5,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
/* ---------- SECONDARY BUTTONS -------------- */
/* these should NOT immediately catch the eye */
--button-secondary-text: var(--secondary-color-1);
--button-secondary-color: var(--primary-color-3);
--button-secondary-border: var(--primary-color-3);
/* ---------hover---------- */
--button-secondary-text-hover:
hsl(44.5,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-hover: var(--primary-color-4);
--button-secondary-border-hover: var(--primary-color-4);
/* ---------active--------- */
--button-secondary-text-active: var(--secondary-color-1);
--button-secondary-color-active:
hsl(201,
calc(var(--primary-color-4-saturation) - 30%),
calc(var(--primary-color-4-lightness) - 15%));
--button-secondary-border-active:
hsl(201,
calc(var(--primary-color-4-saturation) - 30%),
calc(var(--primary-color-4-lightness) - 15%));
/* ---------- TERTIARY BUTTONS --------------- */
/* ---------- disabled buttons --------------- */
--button-tertiary-text: var(--primary-color-4);
--button-tertiary-color: var(--primary-color-2);
--button-tertiary-border: var(--primary-color-2);
/* ---------hover---------- */
--button-tertiary-text: var(--primary-color-4);
--button-tertiary-color: var(--primary-color-2);
--button-tertiary-border: var(--primary-color-2);
}


@ -0,0 +1,201 @@
/* Author: Yazan Agha-Schrader */
.theme-ketivah {
/* ---------- PRIMARY COLORS ----------------- */
--primary-color-1: hsl(0, 0%, 99.2%);
--primary-color-1-hue: 0;
--primary-color-1-saturation: 0%;
--primary-color-1-lightness: 99.2%;
--primary-color-2: hsl(0, 0%, 95%);
--primary-color-2-hue: 0;
--primary-color-2-saturation: 0%;
--primary-color-2-lightness: 95%;
--primary-color-3: hsl(0, 0%, 88%);
--primary-color-3-hue: 0;
--primary-color-3-saturation: 0%;
--primary-color-3-lightness: 88%;
--primary-color-4: hsl(0, 0%, 80%);
--primary-color-4-hue: 0;
--primary-color-4-saturation: 0%;
--primary-color-4-lightness: 80%;
/* ---------- SECONDARY COLORS --------------- */
--secondary-color-1: hsl(0, 0%, 20%);
--secondary-color-1-hue: 0;
--secondary-color-1-saturation: 0%;
--secondary-color-1-lightness: 20%;
--secondary-color-2: hsl(0, 0%, 23.1%);
--secondary-color-2-hue: 0;
--secondary-color-2-saturation: 0%;
--secondary-color-2-lightness: 23.1%;
--secondary-color-3: hsl(0, 0%, 29%);
--secondary-color-3-hue: 0;
--secondary-color-3-saturation: 0%;
--secondary-color-3-lightness: 29%;
--secondary-color-4: hsl(0, 0.0%, 36.1%);
--secondary-color-4-hue: 0.0;
--secondary-color-4-saturation: 0.0%;
--secondary-color-4-lightness: 36.1%;
/* ----------- NUANCES COLORS ---------------- */
--theme-nuance-color-1: hsl(165.2, 0%, 35.1%);
--theme-nuance-color-1-hue: 165.2;
--theme-nuance-color-1-saturation: 82.1%;
--theme-nuance-color-1-lightness: 35.1%;
--theme-nuance-color-2: hsl(165.2, 0%, 35.1%);
--theme-nuance-color-2-hue: 165.2;
--theme-nuance-color-2-saturation: 82.1%;
--theme-nuance-color-2-lightness: 35.1%;
--theme-nuance-color-3: hsl(165.2, 0%, 35.3%);
--theme-nuance-color-3-hue: 165.2;
--theme-nuance-color-3-saturation: 81.1%;
--theme-nuance-color-3-lightness: 35.3%;
--theme-nuance-color-4: hsl(164.9, 0%, 27.6%);
--theme-nuance-color-4-hue: 164.9;
--theme-nuance-color-4-saturation: 81.6%;
--theme-nuance-color-4-lightness: 27.6%;
/* ----------- ROYGP COLORS ------------------ */
--theme-red-color: hsl(0.3, 80.0%, 50.0%);
--theme-orange-color: #e76f51;
--theme-yellow-color: hsl(60, 70.6%, 73.3%);
--theme-green-color: #A3BE8C;
--theme-purple-color: hsl(0.3, 70.0%, 45.0%);
/* ------------------------------------------- */
--background-color-1: var(--primary-color-1);
--background-color-2: var(--primary-color-2);
--background-color-3: var(--primary-color-3);
--background-color-4: var(--primary-color-4);
--border-color-1: var(--primary-color-2);
--border-color-2: var(--primary-color-3);
--border-color-3: var(--primary-color-4);
--border-focus-color: var(--theme-nuance-color-2);
--border-focus-shadow: var(--theme-nuance-color-1);
--text-color-plain: var(--secondary-color-1);
--text-color-subtile-1: var(--secondary-color-2);
--text-color-subtile-2: var(--secondary-color-3);
--code-background-color: var(--secondary-color-2);
--code-text-color: var(--primary-color-2);
--ui-range-thumb-color: var(--primary-color-4);
--ui-range-thumb-border: var(--ui-range-thumb-color);
--textarea-border-color: var(--secondary-color-4);
--chat-id-color: var(--theme-nuance-color-4);
/* ------------------------------------------- */
--button-alert-text-hover: var(--primary-color-1);
--button-alert-color-hover: var(--theme-purple-color);
--button-alert-border-hover: var(--theme-purple-color);
--button-alert-text-active: var(--primary-color-1);
--button-alert-color-active: var(--theme-red-color);
--button-alert-border-active: var(--theme-red-color);
/* ----------- PRIMARY BUTTONS --------------- */
/* - button should immediately catch the eye - */
--button-primary-text:
hsl(0,
calc(var(--primary-color-1-saturation) - 100%),
calc(var(--primary-color-1-lightness) + 100%));
--button-primary-color: var(--theme-nuance-color-3);
--button-primary-border: var(--theme-nuance-color-3);
/* ---------hover---------- */
--button-primary-text-hover:
hsl(0,
calc(var(--primary-color-1-saturation) - 100%),
calc(var(--primary-color-1-lightness) + 100%));
--button-primary-color-hover:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
--button-primary-border-hover:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
/* ---------active--------- */
--button-primary-text-active:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) + 100%));
--button-primary-color-active:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) - 15%));
--button-primary-border-active:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
/* ---------- SECONDARY BUTTONS -------------- */
/* these should NOT immediately catch the eye */
--button-secondary-text:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) - 50%));
--button-secondary-color: var(--primary-color-3);
--button-secondary-border: var(--primary-color-3);
/* ---------hover---------- */
--button-secondary-text-hover:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-hover: var(--primary-color-4);
--button-secondary-border-hover: var(--primary-color-4);
/* ---------active--------- */
--button-secondary-text-active:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-active:
hsl(0,
calc(var(--primary-color-4-saturation) - 100%),
calc(var(--primary-color-4-lightness) - 15%));
--button-secondary-border-active:
hsl(0,
calc(var(--primary-color-4-saturation) - 100%),
calc(var(--primary-color-4-lightness) - 15%));
/* ---------- TERTIARY BUTTONS --------------- */
/* ---------- disabled buttons --------------- */
--button-tertiary-text: var(--primary-color-4);
--button-tertiary-color: var(--primary-color-2);
--button-tertiary-border: var(--primary-color-2);
/* ---------hover---------- */
--button-tertiary-text-hover: var(--primary-color-4);
--button-tertiary-color-hover: var(--primary-color-2);
--button-tertiary-border-hover: var(--primary-color-2);
--loading-color-1: #eeeeee00;
--loading-color-2: #eeeeeeff;
}


@ -0,0 +1,216 @@
/* Author: Yazan Agha-Schrader */
/* Inspiration from llama.cpp logo/banner https://github.com/ggerganov/llama.cpp#readme */
.theme-mangotango {
--primary-color-1: hsl(192, 8.5%, 11.6%);
--primary-color-2: hsl(192, 8.5%, 21%);
--primary-color-3: hsl(192, 8.5%, 30%);
--primary-color-4: hsl(192, 8.5%, 40%);
--secondary-color-1: hsl(192, 8.5%, 80%);
--secondary-color-2: hsl(192, 8.5%, 73%);
--secondary-color-3: hsl(192, 8.5%, 66%);
--secondary-color-4: hsl(192, 8.5%, 60%);
--theme-nuance-color-1: hsl(23.1, 100%, 60.2%);
--theme-nuance-color-2: hsl(23.1, 100%, 60.2%);
--theme-nuance-color-3: hsl(23.1, 100%, 60.2%);
--theme-nuance-color-4: hsl(23.1, 100%, 60.2%);
/* ---------- PRIMARY COLORS ----------------- */
--primary-color-1: hsl(192, 8.5%, 11.6%);
--primary-color-1-saturation: 8.5%;
--primary-color-1-lightness: 11.6%;
--primary-color-2: hsl(192, 8.5%, 21%);
--primary-color-2-saturation: 8.5%;
--primary-color-2-lightness: 21%;
--primary-color-3: hsl(192, 8.5%, 30%);
--primary-color-3-saturation: 8.5%;
--primary-color-3-lightness: 30%;
--primary-color-4: hsl(192, 8.5%, 40%);
--primary-color-4-saturation: 8.5%;
--primary-color-4-lightness: 40%;
/* ---------- SECONDARY COLORS --------------- */
--secondary-color-1: hsl(192, 8.5%, 80%);
--secondary-color-1-saturation: 8.5%;
--secondary-color-1-lightness: 80%;
--secondary-color-2: hsl(192, 8.5%, 73%);
--secondary-color-2-saturation: 8.5%;
--secondary-color-2-lightness: 73%;
--secondary-color-3: hsl(192, 8.5%, 66%);
--secondary-color-3-saturation: 8.5%;
--secondary-color-3-lightness: 66%;
--secondary-color-4: hsl(192, 8.5%, 60%);
--secondary-color-4-saturation: 8.5%;
--secondary-color-4-lightness: 60%;
/* ----------- NUANCES COLORS ---------------- */
--theme-nuance-color-1: hsl(23.1, 100%, 60.2%);
--theme-nuance-color-1-saturation: 100%;
--theme-nuance-color-1-lightness: 60.2%;
--theme-nuance-color-2: hsl(23.1, 100%, 60.2%);
--theme-nuance-color-2-saturation: 100%;
--theme-nuance-color-2-lightness: 60.2%;
--theme-nuance-color-3: hsl(23.1, 100%, 60.2%);
--theme-nuance-color-3-saturation: 100%;
--theme-nuance-color-3-lightness: 60.2%;
--theme-nuance-color-4: hsl(23.1, 100%, 60.2%);
--theme-nuance-color-4-saturation: 100%;
--theme-nuance-color-4-lightness: 60.2%;
/* ----------- ROYGP COLORS ------------------ */
--theme-red-color: hsl(325, 60%, 50%);
--theme-orange-color: #e76f51;
--theme-yellow-color: #ffd95f;
--theme-green-color: #A3BE8C;
--theme-blue-color: hsl(192, 95%, 40%);
--theme-purple-color: hsl(192, 80%, 35%);
/* ------------------------------------------- */
--background-color-1: var(--primary-color-1);
--background-color-2: var(--primary-color-2);
--background-color-3: var(--primary-color-3);
--background-color-4: var(--primary-color-4);
--border-color-1: var(--primary-color-2);
--border-color-2: var(--primary-color-3);
--border-color-3: var(--primary-color-4);
--border-focus-color: var(--theme-nuance-color-2);
--border-focus-shadow: var(--theme-nuance-color-1);
--text-color-plain: var(--secondary-color-1);
--text-color-subtile-1: var(--secondary-color-2);
--text-color-subtile-2: var(--secondary-color-3);
--code-background-color: var(--secondary-color-2);
--code-text-color: var(--primary-color-2);
--ui-range-thumb-color: var(--theme-nuance-color-3);
--ui-range-thumb-border: var(--ui-range-thumb-color);
--textarea-border-color: var(--secondary-color-4);
--chat-id-color: var(--theme-nuance-color-4);
/* ------------------------------------------- */
--button-alert-text-hover: var(--secondary-color-1);
--button-alert-color-hover: var(--theme-purple-color);
--button-alert-border-hover: var(--theme-purple-color);
--button-alert-text-active: var(--secondary-color-1);
--button-alert-color-active: var(--theme-blue-color);
--button-alert-border-active: var(--theme-blue-color);
/* ----------- PRIMARY BUTTONS --------------- */
/* - button should immediately catch the eye - */
--button-primary-text: var(--primary-color-1);
--button-primary-color: var(--theme-nuance-color-3);
--button-primary-border: var(--theme-nuance-color-3);
/* ---------hover---------- */
--button-primary-text-hover:
hsl(192,
calc(var(--primary-color-1-saturation) - 100%),
calc(var(--primary-color-1-lightness) + 100%));
--button-primary-color-hover:
hsl(23.1,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
--button-primary-border-hover:
hsl(23.1,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
/* ---------active--------- */
--button-primary-text-active:
hsl(23.1,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) + 100%));
--button-primary-color-active:
hsl(23.1,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 15%));
--button-primary-border-active:
hsl(23.1,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
/* ---------- SECONDARY BUTTONS -------------- */
/* these should NOT immediately catch the eye */
--button-secondary-text: var(--secondary-color-1);
--button-secondary-color: var(--primary-color-3);
--button-secondary-border: var(--primary-color-3);
/* ---------hover---------- */
--button-secondary-text-hover:
hsl(23.1,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-hover: var(--primary-color-4);
--button-secondary-border-hover: var(--primary-color-4);
/* ---------active--------- */
--button-secondary-text-active: var(--secondary-color-1);
--button-secondary-color-active:
hsl(192,
calc(var(--primary-color-4-saturation) - 30%),
calc(var(--primary-color-4-lightness) - 15%));
--button-secondary-border-active:
hsl(192,
calc(var(--primary-color-4-saturation) - 30%),
calc(var(--primary-color-4-lightness) - 15%));
/* ---------- TERTIARY BUTTONS --------------- */
/* ---------- disabled buttons --------------- */
--button-tertiary-text: var(--primary-color-4);
--button-tertiary-color: var(--primary-color-2);
--button-tertiary-border: var(--primary-color-2);
/* ---------hover---------- */
--button-tertiary-text-hover: var(--primary-color-4);
--button-tertiary-color-hover: var(--primary-color-2);
--button-tertiary-border-hover: var(--primary-color-2);
}


@ -0,0 +1,221 @@
/* Author: Yazan Agha-Schrader */
/* Inspiration from OpenAI's Playground platform https://platform.openai.com/playground/ */
.theme-playground {
/* ---------- PRIMARY COLORS ----------------- */
--primary-color-1: hsl(0, 0%, 99.2%);
--primary-color-1-hue: 0;
--primary-color-1-saturation: 0%;
--primary-color-1-lightness: 99.2%;
--primary-color-2: hsl(0, 0%, 95%);
--primary-color-2-hue: 0;
--primary-color-2-saturation: 0%;
--primary-color-2-lightness: 95%;
--primary-color-3: hsl(0, 0%, 88%);
--primary-color-3-hue: 0;
--primary-color-3-saturation: 0%;
--primary-color-3-lightness: 88%;
--primary-color-4: hsl(0, 0%, 80%);
--primary-color-4-hue: 0;
--primary-color-4-saturation: 0%;
--primary-color-4-lightness: 80%;
/* ---------- SECONDARY COLORS --------------- */
--secondary-color-1: hsl(0, 0%, 20%);
--secondary-color-1-hue: 0;
--secondary-color-1-saturation: 0%;
--secondary-color-1-lightness: 20%;
--secondary-color-2: hsl(0, 0%, 23.1%);
--secondary-color-2-hue: 0;
--secondary-color-2-saturation: 0%;
--secondary-color-2-lightness: 23.1%;
--secondary-color-3: hsl(0, 0%, 29%);
--secondary-color-3-hue: 0;
--secondary-color-3-saturation: 0%;
--secondary-color-3-lightness: 29%;
--secondary-color-4: hsl(0, 0%, 36.1%);
--secondary-color-4-hue: 0;
--secondary-color-4-saturation: 0%;
--secondary-color-4-lightness: 36.1%;
/* ----------- NUANCES COLORS ---------------- */
--theme-nuance-color-1: hsl(165.2, 82.1%, 35.1%);
--theme-nuance-color-1-hue: 165.2;
--theme-nuance-color-1-saturation: 82.1%;
--theme-nuance-color-1-lightness: 35.1%;
--theme-nuance-color-2: hsl(165.2, 82.1%, 35.1%);
--theme-nuance-color-2-hue: 165.2;
--theme-nuance-color-2-saturation: 82.1%;
--theme-nuance-color-2-lightness: 35.1%;
--theme-nuance-color-3: hsl(165.2, 81.1%, 35.3%);
--theme-nuance-color-3-hue: 165.2;
--theme-nuance-color-3-saturation: 81.1%;
--theme-nuance-color-3-lightness: 35.3%;
--theme-nuance-color-4: hsl(164.9, 81.6%, 27.6%);
--theme-nuance-color-4-hue: 164.9;
--theme-nuance-color-4-saturation: 81.6%;
--theme-nuance-color-4-lightness: 27.6%;
/* ----------- ROYGP COLORS ------------------ */
--theme-red-color: hsl(0.3, 80%, 50%);
--theme-orange-color: #e76f51;
--theme-yellow-color: hsl(60, 70.6%, 73.3%);
--theme-green-color: #A3BE8C;
--theme-purple-color: hsl(0.3, 70%, 45%);
/* ------------------------------------------- */
--background-color-1: var(--primary-color-1);
--background-color-2: var(--primary-color-2);
--background-color-3: var(--primary-color-3);
--background-color-4: var(--primary-color-4);
--border-color-1: var(--primary-color-2);
--border-color-2: var(--primary-color-3);
--border-color-3: var(--primary-color-4);
--border-focus-color: var(--theme-nuance-color-2);
--border-focus-shadow: var(--theme-nuance-color-1);
--text-color-plain: var(--secondary-color-1);
--text-color-subtile-1: var(--secondary-color-2);
--text-color-subtile-2: var(--secondary-color-3);
--code-background-color: var(--secondary-color-2);
--code-text-color: var(--primary-color-2);
--ui-range-thumb-color: var(--primary-color-4);
--ui-range-thumb-border: var(--ui-range-thumb-color);
--textarea-border-color: var(--secondary-color-4);
--chat-id-color: var(--theme-nuance-color-4);
/* ------------------------------------------- */
--button-alert-text-hover: var(--primary-color-1);
--button-alert-color-hover: var(--theme-purple-color);
--button-alert-border-hover: var(--theme-purple-color);
--button-alert-text-active: var(--primary-color-1);
--button-alert-color-active: var(--theme-red-color);
--button-alert-border-active: var(--theme-red-color);
/* ----------- PRIMARY BUTTONS --------------- */
/* - button should immediately catch the eye - */
--button-primary-text:
hsl(0,
calc(var(--primary-color-1-saturation) - 100%),
calc(var(--primary-color-1-lightness) + 100%));
--button-primary-color: var(--theme-nuance-color-3);
--button-primary-border: var(--theme-nuance-color-3);
/* ---------hover---------- */
--button-primary-text-hover:
hsl(0,
calc(var(--primary-color-1-saturation) - 100%),
calc(var(--primary-color-1-lightness) + 100%));
--button-primary-color-hover:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
--button-primary-border-hover:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
/* ---------active--------- */
--button-primary-text-active:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 100%),
calc(var(--theme-nuance-color-3-lightness) + 100%));
--button-primary-color-active:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 15%));
--button-primary-border-active:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
/* ---------- SECONDARY BUTTONS -------------- */
/* these should NOT immediately catch the eye */
--button-secondary-text:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 50%));
--button-secondary-color: var(--primary-color-3);
--button-secondary-border: var(--primary-color-3);
/* ---------hover---------- */
--button-secondary-text-hover:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-hover: var(--primary-color-4);
--button-secondary-border-hover: var(--primary-color-4);
/* ---------active--------- */
--button-secondary-text-active:
hsl(165.2,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-active:
hsl(0,
calc(var(--primary-color-4-saturation) - 30%),
calc(var(--primary-color-4-lightness) - 15%));
--button-secondary-border-active:
hsl(0,
calc(var(--primary-color-4-saturation) - 30%),
calc(var(--primary-color-4-lightness) - 15%));
/* ---------- TERTIARY BUTTONS --------------- */
/* ---------- disabled buttons --------------- */
--button-tertiary-text: var(--primary-color-4);
--button-tertiary-color: var(--primary-color-2);
--button-tertiary-border: var(--primary-color-2);
/* ---------hover---------- */
--button-tertiary-text-hover: var(--primary-color-4);
--button-tertiary-color-hover: var(--primary-color-2);
--button-tertiary-border-hover: var(--primary-color-2);
}


@ -0,0 +1,253 @@
/* Author: Yazan Agha-Schrader */
/* Inspiration from Nord Theme https://www.nordtheme.com/docs/colors-and-palettes */
.theme-polarnight {
/* ---------- PRIMARY COLORS ----------------- */
--primary-color-1: hsl(220.0, 16.4%, 21.6%);
--primary-color-1-hue: 220.0;
--primary-color-1-saturation: 16.4%;
--primary-color-1-lightness: 21.6%;
--primary-color-2: hsl(221.7, 16.3%, 27.6%);
--primary-color-2-hue: 221.7;
--primary-color-2-saturation: 16.3%;
--primary-color-2-lightness: 27.6%;
--primary-color-3: hsl(220.0, 16.8%, 31.6%);
--primary-color-3-hue: 220.0;
--primary-color-3-saturation: 16.8%;
--primary-color-3-lightness: 31.6%;
--primary-color-4: hsl(220.0, 16.5%, 35.7%);
--primary-color-4-hue: 220.0;
--primary-color-4-saturation: 16.5%;
--primary-color-4-lightness: 35.7%;
/* ---------- SECONDARY COLORS --------------- */
--secondary-color-1: hsl(217.5, 26.7%, 94.1%);
--secondary-color-1-hue: 217.5;
--secondary-color-1-saturation: 26.7%;
--secondary-color-1-lightness: 94.1%;
--secondary-color-2: hsl(218.2, 26.8%, 92.0%);
--secondary-color-2-hue: 218.2;
--secondary-color-2-saturation: 26.8%;
--secondary-color-2-lightness: 92.0%;
--secondary-color-3: hsl(218.8, 27.9%, 88.0%);
--secondary-color-3-hue: 218.8;
--secondary-color-3-saturation: 27.9%;
--secondary-color-3-lightness: 88.0%;
--secondary-color-4: hsl(218.8, 18.3%, 81.8%);
--secondary-color-4-hue: 218.8;
--secondary-color-4-saturation: 18.3%;
--secondary-color-4-lightness: 81.8%;
/* ----------- NUANCES COLORS ---------------- */
--theme-nuance-color-1: hsl(178.7, 25.1%, 64.9%);
--theme-nuance-color-1-hue: 178.7;
--theme-nuance-color-1-saturation: 25.1%;
--theme-nuance-color-1-lightness: 64.9%;
--theme-nuance-color-2: hsl(193.3, 43.4%, 67.5%);
--theme-nuance-color-2-hue: 193.3;
--theme-nuance-color-2-saturation: 43.4%;
--theme-nuance-color-2-lightness: 67.5%;
--theme-nuance-color-3: hsl(210.0, 34.0%, 63.1%);
--theme-nuance-color-3-hue: 210.0;
--theme-nuance-color-3-saturation: 34.0%;
--theme-nuance-color-3-lightness: 63.1%;
--theme-nuance-color-4: hsl(213.1, 32.0%, 52.2%);
--theme-nuance-color-4-hue: 213.1;
--theme-nuance-color-4-saturation: 32.0%;
--theme-nuance-color-4-lightness: 52.2%;
/* ----------- ROYGP COLORS ------------------ */
--theme-red-color: hsl(354.3, 42.3%, 56.5%);
--theme-orange-color: hsl(20, 85%, 50%);
--theme-yellow-color: hsl(20, 75%, 45%);
--theme-green-color: hsl(92.4, 27.8%, 64.7%);
--theme-purple-color: hsl(311.1, 20.2%, 63.1%);
/* ------------------------------------------------ */
--background-color-1: var(--primary-color-1);
--background-color-2: var(--primary-color-2);
--background-color-3: var(--primary-color-3);
--background-color-4: var(--primary-color-4);
--border-color-1: var(--primary-color-2);
--border-color-2: var(--primary-color-3);
--border-color-3: var(--primary-color-4);
--border-focus-color: var(--theme-nuance-color-2);
--border-focus-shadow: var(--theme-nuance-color-1);
--text-color-plain: var(--secondary-color-1);
--text-color-subtile-1: var(--secondary-color-2);
--text-color-subtile-2: var(--secondary-color-3);
--code-background-color: var(--secondary-color-2);
--code-text-color: var(--primary-color-2);
--ui-range-thumb-color: var(--theme-nuance-color-3);
--ui-range-thumb-border: var(--ui-range-thumb-color);
--textarea-border-color: var(--secondary-color-4);
--chat-id-color: var(--theme-nuance-color-4);
/* ------------------------------------------- */
--button-alert-text-hover: var(--secondary-color-1);
--button-alert-color-hover: var(--theme-yellow-color);
--button-alert-border-hover: var(--theme-yellow-color);
--button-alert-text-active: var(--secondary-color-1);
--button-alert-color-active: var(--theme-orange-color);
--button-alert-border-active: var(--theme-orange-color);
/* ----------- PRIMARY BUTTONS --------------- */
/* - button should immediately catch the eye - */
--button-primary-text: var(--secondary-color-1);
--button-primary-color: var(--theme-nuance-color-3);
--button-primary-border: var(--theme-nuance-color-3);
/* ---------hover---------- */
--button-primary-text-hover:
hsl(217.5,
calc(var(--secondary-color-1-saturation) - 35%),
calc(var(--secondary-color-1-lightness) + 30%));
--button-primary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
--button-primary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
/* ---------active--------- */
--button-primary-text-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 35%));
--button-primary-color-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 25%));
--button-primary-border-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 25%));
/* ---------- SECONDARY BUTTONS -------------- */
/* these should NOT immediately catch the eye */
--button-secondary-text:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 50%));
--button-secondary-color:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
--button-secondary-border:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
/* ---------hover---------- */
--button-secondary-text-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 22%),
calc(var(--theme-nuance-color-3-lightness) + 1%));
--button-secondary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 22%),
calc(var(--theme-nuance-color-3-lightness) + 1%));
/* ---------active--------- */
--button-secondary-text-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 25%));
--button-secondary-color-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 30%),
calc(var(--theme-nuance-color-3-lightness) - 15%));
--button-secondary-border-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 30%),
calc(var(--theme-nuance-color-3-lightness) - 15%));
/* ---------- TERTIARY BUTTONS --------------- */
/* ---------- disabled buttons --------------- */
--button-tertiary-text:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
--button-tertiary-color:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
--button-tertiary-border:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
/* ---------hover---------- */
--button-tertiary-text-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
--button-tertiary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
--button-tertiary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
}


@ -0,0 +1,251 @@
/* Author: Yazan Agha-Schrader */
/* Inspiration from Nord Theme https://www.nordtheme.com/docs/colors-and-palettes */
.theme-snowstorm {
/* ---------- PRIMARY COLORS ----------------- */
--primary-color-1: hsl(217.5, 26.7%, 94.1%);
--primary-color-1-hue: 217.5;
--primary-color-1-saturation: 26.7%;
--primary-color-1-lightness: 94.1%;
--primary-color-2: hsl(218.2, 26.8%, 92.0%);
--primary-color-2-hue: 218.2;
--primary-color-2-saturation: 26.8%;
--primary-color-2-lightness: 92.0%;
--primary-color-3: hsl(218.8, 27.9%, 88.0%);
--primary-color-3-hue: 218.8;
--primary-color-3-saturation: 27.9%;
--primary-color-3-lightness: 88.0%;
--primary-color-4: hsl(218.8, 18.3%, 81.8%);
--primary-color-4-hue: 218.8;
--primary-color-4-saturation: 18.3%;
--primary-color-4-lightness: 81.8%;
/* ---------- SECONDARY COLORS --------------- */
--secondary-color-1: hsl(220.0, 16.4%, 21.6%);
--secondary-color-1-hue: 220.0;
--secondary-color-1-saturation: 16.4%;
--secondary-color-1-lightness: 21.6%;
--secondary-color-2: hsl(221.7, 16.3%, 27.6%);
--secondary-color-2-hue: 221.7;
--secondary-color-2-saturation: 16.3%;
--secondary-color-2-lightness: 27.6%;
--secondary-color-3: hsl(220.0, 16.8%, 31.6%);
--secondary-color-3-hue: 220.0;
--secondary-color-3-saturation: 16.8%;
--secondary-color-3-lightness: 31.6%;
--secondary-color-4: hsl(220.0, 16.5%, 35.7%);
--secondary-color-4-hue: 220.0;
--secondary-color-4-saturation: 16.5%;
--secondary-color-4-lightness: 35.7%;
/* ----------- NUANCES COLORS ---------------- */
--theme-nuance-color-1: hsl(178.7, 25.1%, 64.9%);
--theme-nuance-color-1-hue: 178.7;
--theme-nuance-color-1-saturation: 25.1%;
--theme-nuance-color-1-lightness: 64.9%;
--theme-nuance-color-2: hsl(193.3, 43.4%, 67.5%);
--theme-nuance-color-2-hue: 193.3;
--theme-nuance-color-2-saturation: 43.4%;
--theme-nuance-color-2-lightness: 67.5%;
--theme-nuance-color-3: hsl(210.0, 34.0%, 63.1%);
--theme-nuance-color-3-hue: 210.0;
--theme-nuance-color-3-saturation: 34.0%;
--theme-nuance-color-3-lightness: 63.1%;
--theme-nuance-color-4: hsl(213.1, 32.0%, 52.2%);
--theme-nuance-color-4-hue: 213.1;
--theme-nuance-color-4-saturation: 32.0%;
--theme-nuance-color-4-lightness: 52.2%;
/* ----------- ROYGP COLORS ------------------ */
--theme-red-color: hsl(32.5, 80%, 50%);
--theme-orange-color: hsl(32.5, 70%, 45%);
--theme-yellow-color: hsl(40.0, 0.6%, 73.3%);
--theme-green-color: hsl(92.4, 27.8%, 64.7%);
--theme-purple-color: hsl(311.1, 20.2%, 63.1%);
/* ------------------------------------------- */
--background-color-1: var(--primary-color-1);
--background-color-2: var(--primary-color-2);
--background-color-3: var(--primary-color-3);
--background-color-4: var(--primary-color-4);
--border-color-1: var(--primary-color-2);
--border-color-2: var(--primary-color-3);
--border-color-3: var(--primary-color-4);
--border-focus-color: var(--theme-nuance-color-2);
--border-focus-shadow: var(--theme-nuance-color-1);
--text-color-plain: var(--secondary-color-1);
--text-color-subtile-1: var(--secondary-color-2);
--text-color-subtile-2: var(--secondary-color-3);
--code-background-color: var(--secondary-color-2);
--code-text-color: var(--primary-color-2);
--ui-range-thumb-color: var(--theme-nuance-color-3);
--ui-range-thumb-border: var(--ui-range-thumb-color);
--textarea-border-color: var(--secondary-color-4);
--chat-id-color: var(--theme-nuance-color-4);
/* ------------------------------------------- */
--button-alert-text-hover: var(--primary-color-1);
--button-alert-color-hover: var(--theme-orange-color);
--button-alert-border-hover: var(--theme-orange-color);
--button-alert-text-active: var(--primary-color-1);
--button-alert-color-active: var(--theme-red-color);
--button-alert-border-active: var(--theme-red-color);
/* ----------- PRIMARY BUTTONS --------------- */
/* - button should immediately catch the eye - */
--button-primary-text: var(--secondary-color-1);
--button-primary-color: var(--theme-nuance-color-3);
--button-primary-border: var(--theme-nuance-color-3);
/* ---------hover---------- */
--button-primary-text-hover:
hsl(217.5,
calc(var(--secondary-color-1-saturation) + 35%),
calc(var(--secondary-color-1-lightness) - 30%));
--button-primary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
--button-primary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 2%),
calc(var(--theme-nuance-color-3-lightness) - 10%));
/* ---------active--------- */
--button-primary-text-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 35%));
--button-primary-color-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 25%));
--button-primary-border-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 10%),
calc(var(--theme-nuance-color-3-lightness) - 25%));
/* ---------- SECONDARY BUTTONS -------------- */
/* these should NOT immediately catch the eye */
--button-secondary-text:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 50%));
--button-secondary-color:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
--button-secondary-border:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) + 10%));
/* ---------hover---------- */
--button-secondary-text-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 20%),
calc(var(--theme-nuance-color-3-lightness) - 80%));
--button-secondary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 22%),
calc(var(--theme-nuance-color-3-lightness) + 1%));
--button-secondary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 22%),
calc(var(--theme-nuance-color-3-lightness) + 1%));
/* ---------active--------- */
--button-secondary-text-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) + 40%),
calc(var(--theme-nuance-color-3-lightness) - 55%));
--button-secondary-color-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 30%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
--button-secondary-border-active:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 30%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
/* ---------- TERTIARY BUTTONS --------------- */
/* ---------- disabled buttons --------------- */
--button-tertiary-text:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
--button-tertiary-color:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
--button-tertiary-border:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
/* ---------hover---------- */
--button-tertiary-text-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) - 5%));
--button-tertiary-color-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
--button-tertiary-border-hover:
hsl(210,
calc(var(--theme-nuance-color-3-saturation) - 40%),
calc(var(--theme-nuance-color-3-lightness) + 20%));
}


@ -0,0 +1,266 @@
//@ts-check
// Helpers to work with different data types
// by Humans for All
//
/**
* Given the limited context size of local LLMs, many a time when the context gets filled
* up between the prompt and the response, it can lead to repeating text garbage being
* generated. And many a time setting a penalty wrt repetition leads to over-intelligent
* garbage repetition with slight variations. This garbage can in turn overload the
* available model context, leading to less valuable responses for subsequent prompts/queries,
* if the chat history is sent to the ai model.
*
* So two simple-minded garbage trimming logics are experimented with below:
* * one based on progressively-larger-substring-based repeat matching with partial skip, and
* * another based on char-histogram-driven garbage trimming.
* * in future, characteristics of the histogram over varying lengths could be used to allow
*   a more aggressive and adaptive trimming logic.
*/
/**
* Simple-minded logic to help remove repeating garbage at the end of a string.
* The repetition needs to be a perfect match.
*
* The logic progressively goes on probing for longer and longer substring based
* repetition, till there is no more repetition, and in turn picks the one with
* the longest repeated chain.
*
* @param {string} sIn
* @param {number} maxSubL
* @param {number} maxMatchLenThreshold
*/
export function trim_repeat_garbage_at_end(sIn, maxSubL=10, maxMatchLenThreshold=40) {
let rCnt = [0];
let maxMatchLen = maxSubL;
let iMML = -1;
for(let subL=1; subL < maxSubL; subL++) {
rCnt.push(0);
let i;
let refS = sIn.substring(sIn.length-subL, sIn.length);
for(i=sIn.length; i > 0; i -= subL) {
let curS = sIn.substring(i-subL, i);
if (refS != curS) {
let curMatchLen = rCnt[subL]*subL;
if (maxMatchLen < curMatchLen) {
maxMatchLen = curMatchLen;
iMML = subL;
}
break;
}
rCnt[subL] += 1;
}
}
console.debug("DBUG:DU:TrimRepeatGarbage:", rCnt);
if ((iMML == -1) || (maxMatchLen < maxMatchLenThreshold)) {
return {trimmed: false, data: sIn};
}
console.debug("DBUG:TrimRepeatGarbage:TrimmedCharLen:", maxMatchLen);
let iEnd = sIn.length - maxMatchLen;
return { trimmed: true, data: sIn.substring(0, iEnd) };
}
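// Illustrative sketch of the expected behaviour (this example is not part of the
// original file): a long, perfectly repeating tail above the default threshold gets cut,
//   trim_repeat_garbage_at_end("Answer: 42." + " ha".repeat(30))
//   // -> { trimmed: true, data: "Answer: 42." }
// while short or non-repeating tails are returned unchanged with trimmed: false.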
/**
* Simple-minded logic to help remove repeating garbage at the end of a string, till it can't.
* If it is not able to trim, then it will try to skip a char at the end and then trim again, a few times.
* This ensures that even if there are multiple runs of garbage with different patterns, the
* logic still tries to munch through them.
*
* @param {string} sIn
* @param {number} maxSubL
* @param {number | undefined} [maxMatchLenThreshold]
*/
export function trim_repeat_garbage_at_end_loop(sIn, maxSubL, maxMatchLenThreshold, skipMax=16) {
let sCur = sIn;
let sSaved = "";
let iTry = 0;
while(true) {
let got = trim_repeat_garbage_at_end(sCur, maxSubL, maxMatchLenThreshold);
if (got.trimmed != true) {
if (iTry == 0) {
sSaved = got.data;
}
iTry += 1;
if (iTry >= skipMax) {
return sSaved;
}
got.data = got.data.substring(0,got.data.length-1);
} else {
iTry = 0;
}
sCur = got.data;
}
}
/**
* A simple-minded attempt to trim garbage at the end using histogram driven characteristics.
* There can be variation in the repetitions, as long as no new char props up.
*
* This tracks the chars and their frequency in a specified length of substring at the end
* and in turn checks whether moving further into the generated text from the end remains within
* the same char subset or goes beyond it, and based on that either trims the string at the
* end or not. This allows filtering garbage at the end, even if there are certain
* kinds of small variations in the repeated text wrt the position of the seen chars.
*
* Allow the garbage to contain up to maxUniq chars, but at the same time ensure that
* a given type of char, ie numerals or alphabets or other types, doesn't cross the specified
* maxType limit. This allows intermixed text garbage to be identified and trimmed.
*
* ALERT: This is not perfect and only provides a rough garbage identification logic.
* Also it currently only differentiates between character classes wrt English.
*
* @param {string} sIn
* @param {number} maxType
* @param {number} maxUniq
* @param {number} maxMatchLenThreshold
*/
export function trim_hist_garbage_at_end(sIn, maxType, maxUniq, maxMatchLenThreshold) {
if (sIn.length < maxMatchLenThreshold) {
return { trimmed: false, data: sIn };
}
let iAlp = 0;
let iNum = 0;
let iOth = 0;
// Learn
let hist = {};
let iUniq = 0;
for(let i=0; i<maxMatchLenThreshold; i++) {
let c = sIn[sIn.length-1-i];
if (c in hist) {
hist[c] += 1;
} else {
if(c.match(/[0-9]/) != null) {
iNum += 1;
} else if(c.match(/[A-Za-z]/) != null) {
iAlp += 1;
} else {
iOth += 1;
}
iUniq += 1;
if (iUniq >= maxUniq) {
break;
}
hist[c] = 1;
}
}
console.debug("DBUG:TrimHistGarbage:", hist);
if ((iAlp > maxType) || (iNum > maxType) || (iOth > maxType)) {
return { trimmed: false, data: sIn };
}
// Catch and Trim
for(let i=0; i < sIn.length; i++) {
let c = sIn[sIn.length-1-i];
if (!(c in hist)) {
if (i < maxMatchLenThreshold) {
return { trimmed: false, data: sIn };
}
console.debug("DBUG:TrimHistGarbage:TrimmedCharLen:", i);
return { trimmed: true, data: sIn.substring(0, sIn.length-i+1) };
}
}
console.debug("DBUG:TrimHistGarbage:Trimmed fully");
return { trimmed: true, data: "" };
}
/**
* Keep trimming repeatedly using the hist_garbage logic, till it no longer can.
* This ensures that even if there are multiple runs of garbage with different patterns,
* the logic still tries to munch through them.
*
* @param {any} sIn
* @param {number} maxType
* @param {number} maxUniq
* @param {number} maxMatchLenThreshold
*/
export function trim_hist_garbage_at_end_loop(sIn, maxType, maxUniq, maxMatchLenThreshold) {
let sCur = sIn;
while (true) {
let got = trim_hist_garbage_at_end(sCur, maxType, maxUniq, maxMatchLenThreshold);
if (!got.trimmed) {
return got.data;
}
sCur = got.data;
}
}
/**
* Try to trim garbage at the end by using both the hist-driven garbage trimming as well as
* the skip-a-bit-if-required-then-repeat-pattern-based garbage trimming, with blind retrying.
* @param {string} sIn
*/
export function trim_garbage_at_end(sIn) {
let sCur = sIn;
for(let i=0; i<2; i++) {
sCur = trim_hist_garbage_at_end_loop(sCur, 8, 24, 72);
sCur = trim_repeat_garbage_at_end_loop(sCur, 32, 72, 12);
}
return sCur;
}
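// Illustrative usage (not part of the original file): clean a freshly generated
// response before it is shown/stored, e.g.
//   let sCleaned = trim_garbage_at_end(sResponseFromModel);
// where sResponseFromModel is whatever text the server/ai-model returned.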
/**
* NewLines array helper.
* Allows maintaining a list of lines.
* Allows a line to be built up / appended part by part.
*/
export class NewLines {
constructor() {
/** @type {string[]} */
this.lines = [];
}
/**
* Extracts lines from the passed string and in turn either
* appends to a previous partial line or adds a new line.
* @param {string} sLines
*/
add_append(sLines) {
let aLines = sLines.split("\n");
let lCnt = 0;
for(let line of aLines) {
lCnt += 1;
// Add back newline removed if any during split
if (lCnt < aLines.length) {
line += "\n";
} else {
if (sLines.endsWith("\n")) {
line += "\n";
}
}
// Append if required
if (lCnt == 1) {
let lastLine = this.lines[this.lines.length-1];
if (lastLine != undefined) {
if (!lastLine.endsWith("\n")) {
this.lines[this.lines.length-1] += line;
continue;
}
}
}
// Add new line
this.lines.push(line);
}
}
/**
* Shift out the oldest/earliest/0th line in the array. [Old-New|Earliest-Latest]
* Optionally control whether only full lines (ie those with a newline at the end) are returned,
* or whether a partial line without a newline at the end (which can only be the last line) may also be returned.
* @param {boolean} bFullWithNewLineOnly
*/
shift(bFullWithNewLineOnly=true) {
let line = this.lines[0];
if (line == undefined) {
return undefined;
}
if ((line[line.length-1] != "\n") && bFullWithNewLineOnly){
return undefined;
}
return this.lines.shift();
}
}
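// Illustrative usage (not part of the original file): feed streamed chunks in and pull
// out completed lines; a trailing partial line stays queued until its newline arrives,
// unless shift(false) is used.
//   let nl = new NewLines();
//   nl.add_append("data: hel");
//   nl.add_append("lo\ndata: wor");
//   nl.shift();      // -> "data: hello\n"
//   nl.shift();      // -> undefined, "data: wor" has no newline yet
//   nl.shift(false); // -> "data: wor"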


@ -0,0 +1,51 @@
<!DOCTYPE html>
<html lang="en">
<head>
<title>SimpleChat LlamaCppEtal </title>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="message" content="Save Nature Save Earth" />
<meta name="description" content="SimpleChat: trigger LLM web service endpoints /chat/completions and /completions, single/multi chat sessions" />
<meta name="author" content="by Humans for All" />
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate" />
<script type="importmap">
{
"imports": {
"datautils": "./datautils.mjs",
"ui": "./ui.mjs"
}
}
</script>
<script src="simplechat.js" type="module" defer></script>
<link rel="stylesheet" href="simplechat.css" />
</head>
<body>
<div class="samecolumn" id="fullbody">
<div class="sameline" id="heading">
<p class="heading flex-grow" > <b> SimpleChat </b> </p>
<button id="settings">Settings</button>
</div>
<div id="sessions-div" class="sameline"></div>
<hr>
<div class="sameline">
<label for="system-in">System</label>
<textarea name="system" id="system-in" rows="2" placeholder="e.g. you are a helpful ai assistant, who provides concise answers" class="flex-grow"></textarea>
</div>
<hr>
<div id="chat-div">
<p> You need to have javascript enabled.</p>
</div>
<hr>
<div class="sameline">
<textarea id="user-in" class="flex-grow" rows="2" placeholder="enter your query to the ai model here" ></textarea>
<button id="user-btn">submit</button>
</div>
</div>
</body>
</html>


@ -0,0 +1,286 @@
# SimpleChat
by Humans for All.
## quickstart
To run from the build dir
bin/llama-server -m path/model.gguf --path ../tools/server/public_simplechat
Continue reading for the details.
## overview
This simple web frontend allows triggering/testing the server's /completions or /chat/completions endpoints
in a simple way with minimal code from a common code base. In addition it tries to allow single or
multiple independent back-and-forth chat sessions with the ai llm model at a basic level, each with
its own system prompt.
The generated text / ai-model response can be viewed in one shot at the end, after it is fully generated,
or potentially as it is being generated, in a streamed manner from the server/ai-model.
![Chat and Settings screens](./simplechat_screens.webp "Chat and Settings screens")
The chat session is auto-saved locally as the chat progresses, and at a later time, when you open
SimpleChat, an option is provided to restore the old chat session, if a matching one exists.
The UI follows a responsive web design so that the layout can adapt to available display space in a usable
enough manner, in general.
The developer/end-user can control some of the behaviour by updating gMe members from the browser's devel-tool
console. In parallel, some of the settings directly useful to the end user can also be changed using the provided
settings ui.
NOTE: The current web service api doesn't expose the model context length directly, so the client logic doesn't provide
any adaptive culling of old messages, nor replacing them with a summary of their content, etc. However there
is an optional sliding-window based chat logic, which provides a simple-minded culling of old messages from
the chat history before sending to the ai model.
NOTE: Wrt options sent with the request, it mainly sets temperature, max_tokens and optionally stream for now.
However if someone wants, they can update the js file or the equivalent member in gMe as needed.
NOTE: One may be able to use this to chat with openai api web-service /chat/completions endpoint, in a very
limited / minimal way. One will need to set model, openai url and authorization bearer key in settings ui.
## usage
One can run this web frontend directly using the server itself, or, for example if one is thinking of adding a built-in
web frontend to configure the server over http(s), one can run this web frontend using something like python's
http module.
### running using tools/server
./llama-server -m path/model.gguf --path tools/server/public_simplechat [--port PORT]
### running using python3's server module
first run tools/server
* ./llama-server -m path/model.gguf
next run this web front end in tools/server/public_simplechat
* cd ../tools/server/public_simplechat
* python3 -m http.server PORT
### using the front end
Open this simple web front end from your local browser
* http://127.0.0.1:PORT/index.html
Once inside
* If you want to, you can change many of the default global settings
* the base url (ie ip addr / domain name, port)
* chat (default) vs completion mode
* try trim garbage in response or not
* amount of chat history in the context sent to server/ai-model
* oneshot or streamed mode.
* In completion mode
* one normally doesn't use a system prompt in completion mode.
* by default the logic doesn't insert any role specific "ROLE: " prefix wrt each role's message.
If the model requires any prefix wrt user role messages, then the end user has to
explicitly add the needed prefix when they enter their chat message.
Similarly, if the model requires any prefix to trigger the assistant/ai-model response,
then the end user needs to enter the same.
This keeps the logic simple, while still giving the end user the flexibility to
manage any templating/tagging requirement wrt their messages to the model.
* the logic doesn't insert a newline at the beginning and end of the generated prompt message.
However if the chat being sent to the /completions endpoint has more than one role's message,
then a newline is inserted when moving from one role's message to the next role's message, so
that they can be clearly identified/distinguished.
* given that the /completions endpoint normally doesn't add additional chat templating of its
own, the above ensures that the end user can create a custom single/multi message combo with
any tags/special-tokens related chat templating to test out the model handshake. Or the end user
can use it just for a normal completion related/based query.
* If you want to provide a system prompt, then ideally enter it first, before entering any user query.
Normally Completion mode doesn't need a system prompt, while Chat mode can generate better/more interesting
responses with a suitable system prompt.
* if chat.add_system_begin is used
* you can't change the system prompt after it has been submitted once along with a user query.
* you can't set a system prompt after you have submitted any user query.
* if chat.add_system_anytime is used
* one can change the system prompt at any time during the chat, by changing the contents of the system prompt.
* in turn the updated/changed system prompt will be inserted into the chat session.
* this allows the subsequent user chatting to be driven by the new system prompt set above.
* Enter your query and either press enter or click on the submit button.
If you want to insert enter (\n) as part of your chat/query to ai model, use shift+enter.
* Wait for the logic to communicate with the server and get the response.
* the user is not allowed to enter any fresh query during this time.
* the user input box will be disabled and a working message will be shown in it.
* if trim garbage is enabled, the logic will try to trim repeating text kind of garbage to some extent.
* just refresh the page, to reset wrt the chat history and or system prompt and start afresh.
* Using NewChat one can start independent chat sessions.
* two independent chat sessions are setup by default.
* When you want to print, switching ChatHistoryInCtxt to Full and clicking on the chat session button of
interest will display the full chat history of that session till then, so the full history can be printed.
## Devel note
### Reason behind this
The idea is to be easy enough to use for basic purposes, while also being simple and easily discernible
by developers who may not be from a web frontend background (and so may not be familiar with template /
end-use-specific-language-extensions driven flows), so that they can use it to explore/experiment with things.
And given that the idea is also to help developers explore/experiment, some flexibility is provided
to change behaviour easily using the devel-tools/console or the provided minimal settings ui (wrt a few aspects).
Skeletal logic has been implemented to explore some of the end points and ideas/implications around them.
### General
Me/gMe consolidates the settings which control the behaviour into one object.
One can see the current settings, as well as change/update them, using the browser's devel-tool/console.
It is attached to the document object. Some of these can also be updated using the Settings UI.
baseURL - the domain-name/ip-address and in turn the port to send the request to.
bStream - controls between oneshot-at-end and live-stream-as-it-is-generated collating and showing
of the generated response.
the logic assumes that the text sent from the server follows utf-8 encoding.
in streaming mode - if there is any exception, the logic traps the same and tries to ensure
that the text generated till then is not lost.
if a very long text is being generated, which leads to no user interaction for some time and
in turn the machine goes into power saving mode or so, the platform may stop the network connection,
leading to an exception.
apiEP - select between /completions and /chat/completions endpoint provided by the server/ai-model.
bCompletionFreshChatAlways - whether Completion mode collates complete/sliding-window history when
communicating with the server or only sends the latest user query/message.
bCompletionInsertStandardRolePrefix - whether Completion mode inserts a role related prefix wrt the
messages that get inserted into the prompt field of the /completions endpoint.
bTrimGarbage - whether garbage repetition at the end of the generated ai response should be
trimmed or left as is. If enabled, it will be trimmed so that it won't be sent back as part of
the subsequent chat history. At the same time the actual trimmed text is shown to the user, once
when it was generated, so the user can check if any useful info/data was there in the response.
One may be able to request the ai-model to continue (wrt the last response) (if chat-history
is enabled as part of the chat-history-in-context setting), and chances are the ai-model will
continue starting from the trimmed part, thus allowing a long response to be recovered/continued
indirectly, in many cases.
The histogram/freq based trimming logic is currently tuned for the english language wrt its
is-it-an-alphabetic|numeral-char regex match logic.
apiRequestOptions - maintains the list of options/fields to send along with the api request,
irrespective of whether the /chat/completions or /completions endpoint is used.
If you want to add additional options/fields to send to the server/ai-model, and or
modify the existing options' values or remove them, for now you can update this global var
using the browser's development-tools/console.
For string, numeric and boolean fields in apiRequestOptions, including even those added by a
user at runtime by directly modifying gMe.apiRequestOptions, settings ui entries will be auto
created.
The cache_prompt option supported by tools/server can be controlled by the user, so that
any caching supported wrt the system prompt and chat history, if usable, can get used. When the chat
history sliding window is enabled, the cache_prompt logic may or may not kick in at the backend
wrt the same, based on aspects related to the model, positional encoding, attention mechanism etc.
However the system prompt should ideally get the benefit of caching.
headers - maintains the list of http headers sent when a request is made to the server. By default
Content-Type is set to application/json. Additionally an Authorization entry is provided, which can
be set if needed using the settings ui.
iRecentUserMsgCnt - a simple-minded sliding window to limit the context window load at the ai model end.
This is disabled by default. However if enabled, then in addition to the latest system message, only
the last/latest iRecentUserMsgCnt user messages after the latest system prompt, and their responses
from the ai model, will be sent to the ai-model when querying for a new response. IE if enabled,
only user messages after the latest system message/prompt will be considered.
This specified sliding window user message count also includes the latest user query.
<0 : Send the entire chat history to the server
0 : Send only the system message, if any, to the server
>0 : Send the latest chat history from the latest system prompt, limited to the specified count.
By using gMe's iRecentUserMsgCnt and apiRequestOptions.max_tokens/n_predict one can try to control
the implications of loading the ai-model's context window with chat history, wrt the chat response, to
some extent in a simple crude way. You may also want to control the context size enabled when the
server loads the ai-model, on the server end.
Sometimes the browser may be stubborn with caching of the files, so your updates to html/css/js
may not be visible. Also remember that just refreshing/reloading the page in the browser, or for that
matter clearing site data, doesn't directly override site caching in all cases. Worst case you may
have to change the port. Or in the browser's dev tools you may be able to disable caching fully.
Currently the server to communicate with is maintained globally and not as part of a specific
chat session. So if one changes the server ip/url in the settings, then all chat sessions will auto
switch to this new server, when you try using those sessions.
By switching between chat.add_system_begin/anytime, one can control whether the system prompt
can be changed anytime during the conversation or only at the beginning.
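For example (an illustrative sketch typed into the browser's devel-tool console; the values are arbitrary, the member names are the ones described above):
```javascript
gMe.apiRequestOptions.max_tokens = 2048;  // allow longer responses
gMe.iRecentUserMsgCnt = 4;                // sliding window of the last 4 user messages
gMe.bTrimGarbage = true;                  // trim repeating garbage at the end of responses
```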
### Default setup
By default things are set up to try and make the user experience a bit better, if possible.
However a developer testing the server or the ai-model may want to change these values.
iRecentUserMsgCnt is used to reduce the chat history context sent to the server/ai-model to
just the system prompt, the previous user request and ai response, and the current user request, instead of the
full chat history. This way if there is any response with garbage/repetition, it doesn't
mess with things beyond the next question/request/query, in some ways. The trim garbage
option also tries to help avoid issues with garbage in the context to an extent.
max_tokens is set to 1024, so that a relatively large previous response doesn't eat up the space
available wrt the next query-response. However don't forget that the server, when started, should
also be started with a model context size of 1k or more, to be on the safe side.
The /completions endpoint of tools/server doesn't take max_tokens, instead it takes the
internal n_predict; for now the same is added here on the client side, maybe later max_tokens will be added
to the /completions endpoint handling code on the server side.
NOTE: One may want to experiment with the frequency/presence penalty fields in apiRequestOptions
wrt the set of fields sent to the server along with the user query, to check how the model behaves
wrt repetitions in general in the generated text response.
An end-user can change this behaviour by editing gMe from the browser's devel-tool/console or by
using the provided settings ui (for settings exposed through the ui).
### OpenAi / Equivalent API WebService
One may be able to handshake with an OpenAI/equivalent api web service's /chat/completions endpoint
for minimal chatting experimentation by setting the below.
* the baseUrl in settings ui
* https://api.openai.com/v1 or similar
* Wrt request body - gMe.apiRequestOptions
* model (settings ui)
* any additional fields if required in future
* Wrt request headers - gMe.headers
* Authorization (available through settings ui)
* Bearer THE_OPENAI_API_KEY
* any additional optional header entries like "OpenAI-Organization", "OpenAI-Project" or so
NOTE: Not tested, as there is no free tier api testing available. However logically this might
work.
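For instance (an untested sketch, in line with the note above; the model name and key are placeholders):
```javascript
gMe.baseURL = "https://api.openai.com/v1";
gMe.apiRequestOptions["model"] = "THE_MODEL_NAME";          // placeholder
gMe.headers["Authorization"] = "Bearer THE_OPENAI_API_KEY"; // placeholder
```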
## At the end
Also a thank you to all open source and open model developers, who strive for the common good.


@ -0,0 +1,79 @@
/**
* the styling of the simplechat web frontend
* by Humans for All
*/
#fullbody {
height: 98vh;
}
.heading {
background-color: lightgray;
}
.session-selected {
background-color: lightblue;
}
.role-system {
background-color: lightblue;
}
.role-user {
background-color: lightgray;
}
.role-trim {
background-color: lightpink;
}
.gridx2 {
display: grid;
grid-template-columns: repeat(2, 1fr);
border-bottom-style: dotted;
border-bottom-width: thin;
border-bottom-color: lightblue;
}
.flex-grow {
flex-grow: 1;
}
.float-right {
float: right;
}
#chat-div {
overflow: scroll;
flex-grow: 1;
flex-shrink: 1;
min-height: 40vh;
}
button {
min-width: 8vw;
}
.sameline {
display: flex;
flex-direction: row;
}
.samecolumn {
display: flex;
flex-direction: column;
}
.ul1 {
padding-inline-start: 2vw;
}
.ul2 {
padding-inline-start: 2vw;
}
* {
margin: 0.6vmin;
}
@media print {
#fullbody {
height: auto;
}
}


@ -0,0 +1,929 @@
// @ts-check
// A simple completions and chat/completions test related web front end logic
// by Humans for All
import * as du from "./datautils.mjs";
import * as ui from "./ui.mjs"
class Roles {
static System = "system";
static User = "user";
static Assistant = "assistant";
}
class ApiEP {
static Type = {
Chat: "chat",
Completion: "completion",
}
static UrlSuffix = {
'chat': `/chat/completions`,
'completion': `/completions`,
}
/**
* Build the url from given baseUrl and apiEp id.
* @param {string} baseUrl
* @param {string} apiEP
*/
static Url(baseUrl, apiEP) {
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.substring(0, baseUrl.length-1);
}
return `${baseUrl}${this.UrlSuffix[apiEP]}`;
}
}
let gUsageMsg = `
<p class="role-system">Usage</p>
<ul class="ul1">
<li> System prompt above, to try and control ai response characteristics.</li>
<ul class="ul2">
<li> Completion mode - no system prompt normally.</li>
</ul>
<li> Use shift+enter for inserting enter/newline.</li>
<li> Enter your query to ai assistant below.</li>
<li> Default ContextWindow = [System, Last Query+Resp, Cur Query].</li>
<ul class="ul2">
<li> ChatHistInCtxt, MaxTokens, ModelCtxt window to expand</li>
</ul>
</ul>
`;
/** @typedef {{role: string, content: string}[]} ChatMessages */
/** @typedef {{iLastSys: number, xchat: ChatMessages}} SimpleChatODS */
class SimpleChat {
/**
* @param {string} chatId
*/
constructor(chatId) {
this.chatId = chatId;
/**
* Maintain in a form suitable for common LLM web service chat/completions' messages entry
* @type {ChatMessages}
*/
this.xchat = [];
this.iLastSys = -1;
this.latestResponse = "";
}
clear() {
this.xchat = [];
this.iLastSys = -1;
}
ods_key() {
return `SimpleChat-${this.chatId}`
}
save() {
/** @type {SimpleChatODS} */
let ods = {iLastSys: this.iLastSys, xchat: this.xchat};
localStorage.setItem(this.ods_key(), JSON.stringify(ods));
}
load() {
let sods = localStorage.getItem(this.ods_key());
if (sods == null) {
return;
}
/** @type {SimpleChatODS} */
let ods = JSON.parse(sods);
this.iLastSys = ods.iLastSys;
this.xchat = ods.xchat;
}
/**
* Recent chat messages.
* If iRecentUserMsgCnt < 0
* Then return the full chat history
* Else
* Return chat messages from latest going back till the last/latest system prompt.
* While keeping track that the number of user queries/messages doesn't exceed iRecentUserMsgCnt.
* @param {number} iRecentUserMsgCnt
*/
recent_chat(iRecentUserMsgCnt) {
if (iRecentUserMsgCnt < 0) {
return this.xchat;
}
if (iRecentUserMsgCnt == 0) {
console.warn("WARN:SimpleChat:SC:RecentChat:iRecentUsermsgCnt of 0 means no user message/query sent");
}
/** @type{ChatMessages} */
let rchat = [];
let sysMsg = this.get_system_latest();
if (sysMsg.length != 0) {
rchat.push({role: Roles.System, content: sysMsg});
}
let iUserCnt = 0;
let iStart = this.xchat.length;
for(let i=this.xchat.length-1; i > this.iLastSys; i--) {
if (iUserCnt >= iRecentUserMsgCnt) {
break;
}
let msg = this.xchat[i];
if (msg.role == Roles.User) {
iStart = i;
iUserCnt += 1;
}
}
for(let i = iStart; i < this.xchat.length; i++) {
let msg = this.xchat[i];
if (msg.role == Roles.System) {
continue;
}
rchat.push({role: msg.role, content: msg.content});
}
return rchat;
}
/**
* Collate the latest response from the server/ai-model, as it is becoming available.
* This is mainly useful for the stream mode.
* @param {string} content
*/
append_response(content) {
this.latestResponse += content;
}
/**
* Add an entry into xchat
* @param {string} role
* @param {string|undefined|null} content
*/
add(role, content) {
if ((content == undefined) || (content == null) || (content == "")) {
return false;
}
this.xchat.push( {role: role, content: content} );
if (role == Roles.System) {
this.iLastSys = this.xchat.length - 1;
}
this.save();
return true;
}
/**
* Show the contents in the specified div
* @param {HTMLDivElement} div
* @param {boolean} bClear
*/
show(div, bClear=true) {
if (bClear) {
div.replaceChildren();
}
let last = undefined;
for(const x of this.recent_chat(gMe.iRecentUserMsgCnt)) {
let entry = ui.el_create_append_p(`${x.role}: ${x.content}`, div);
entry.className = `role-${x.role}`;
last = entry;
}
if (last !== undefined) {
last.scrollIntoView(false);
} else {
if (bClear) {
div.innerHTML = gUsageMsg;
gMe.setup_load(div, this);
gMe.show_info(div);
}
}
return last;
}
/**
* Setup the fetch headers.
* It picks the headers from gMe.headers.
* It inserts Authorization only if it is non-empty.
* @param {string} apiEP
*/
fetch_headers(apiEP) {
let headers = new Headers();
for(let k in gMe.headers) {
let v = gMe.headers[k];
if ((k == "Authorization") && (v.trim() == "")) {
continue;
}
headers.append(k, v);
}
return headers;
}
/**
* Add the needed fields to the json object to be sent to the LLM web service's completions endpoint.
* The needed fields/options are picked from a global object.
* Add optional stream flag, if required.
* Convert the json into string.
* @param {Object} obj
*/
request_jsonstr_extend(obj) {
for(let k in gMe.apiRequestOptions) {
obj[k] = gMe.apiRequestOptions[k];
}
if (gMe.bStream) {
obj["stream"] = true;
}
return JSON.stringify(obj);
}
/**
* Return a string form of json object suitable for chat/completions
*/
request_messages_jsonstr() {
let req = {
messages: this.recent_chat(gMe.iRecentUserMsgCnt),
}
return this.request_jsonstr_extend(req);
}
/**
* Return a string form of json object suitable for /completions
* @param {boolean} bInsertStandardRolePrefix Insert "<THE_ROLE>: " as prefix wrt each role's message
*/
request_prompt_jsonstr(bInsertStandardRolePrefix) {
let prompt = "";
let iCnt = 0;
for(const chat of this.recent_chat(gMe.iRecentUserMsgCnt)) {
iCnt += 1;
if (iCnt > 1) {
prompt += "\n";
}
if (bInsertStandardRolePrefix) {
prompt += `${chat.role}: `;
}
prompt += `${chat.content}`;
}
let req = {
prompt: prompt,
}
return this.request_jsonstr_extend(req);
}
/**
* Return a string form of json object suitable for specified api endpoint.
* @param {string} apiEP
*/
request_jsonstr(apiEP) {
if (apiEP == ApiEP.Type.Chat) {
return this.request_messages_jsonstr();
} else {
return this.request_prompt_jsonstr(gMe.bCompletionInsertStandardRolePrefix);
}
}
/**
* Extract the ai-model/assistant's response from the http response received.
* Optionally trim the message wrt any garbage at the end.
* @param {any} respBody
* @param {string} apiEP
*/
response_extract(respBody, apiEP) {
let assistant = "";
if (apiEP == ApiEP.Type.Chat) {
assistant = respBody["choices"][0]["message"]["content"];
} else {
try {
assistant = respBody["choices"][0]["text"];
} catch {
assistant = respBody["content"];
}
}
return assistant;
}
/**
* Extract the ai-model/assistant's response from the http response received in streaming mode.
* @param {any} respBody
* @param {string} apiEP
*/
response_extract_stream(respBody, apiEP) {
let assistant = "";
if (apiEP == ApiEP.Type.Chat) {
if (respBody["choices"][0]["finish_reason"] !== "stop") {
assistant = respBody["choices"][0]["delta"]["content"];
}
} else {
try {
assistant = respBody["choices"][0]["text"];
} catch {
assistant = respBody["content"];
}
}
return assistant;
}
/**
* Allow setting of the system prompt, but only at the beginning.
* @param {string} sysPrompt
* @param {string} msgTag
*/
add_system_begin(sysPrompt, msgTag) {
if (this.xchat.length == 0) {
if (sysPrompt.length > 0) {
return this.add(Roles.System, sysPrompt);
}
} else {
if (sysPrompt.length > 0) {
if (this.xchat[0].role !== Roles.System) {
console.error(`ERRR:SimpleChat:SC:${msgTag}:You need to specify system prompt before any user query, ignoring...`);
} else {
if (this.xchat[0].content !== sysPrompt) {
console.error(`ERRR:SimpleChat:SC:${msgTag}:You can't change the system prompt midway through, ignoring...`);
}
}
}
}
return false;
}
/**
* Allow setting of system prompt, at any time.
* @param {string} sysPrompt
* @param {string} msgTag
*/
add_system_anytime(sysPrompt, msgTag) {
if (sysPrompt.length <= 0) {
return false;
}
if (this.iLastSys < 0) {
return this.add(Roles.System, sysPrompt);
}
let lastSys = this.xchat[this.iLastSys].content;
if (lastSys !== sysPrompt) {
return this.add(Roles.System, sysPrompt);
}
return false;
}
/**
* Retrieve the latest system prompt.
*/
get_system_latest() {
if (this.iLastSys == -1) {
return "";
}
let sysPrompt = this.xchat[this.iLastSys].content;
return sysPrompt;
}
/**
* Handle the multipart response from server/ai-model
* @param {Response} resp
* @param {string} apiEP
* @param {HTMLDivElement} elDiv
*/
async handle_response_multipart(resp, apiEP, elDiv) {
let elP = ui.el_create_append_p("", elDiv);
if (!resp.body) {
throw Error("ERRR:SimpleChat:SC:HandleResponseMultiPart:No body...");
}
let tdUtf8 = new TextDecoder("utf-8");
let rr = resp.body.getReader();
this.latestResponse = "";
let xLines = new du.NewLines();
while(true) {
let { value: cur, done: done } = await rr.read();
if (cur) {
let curBody = tdUtf8.decode(cur, {stream: true});
console.debug("DBUG:SC:PART:Str:", curBody);
xLines.add_append(curBody);
}
while(true) {
let curLine = xLines.shift(!done);
if (curLine == undefined) {
break;
}
if (curLine.trim() == "") {
continue;
}
if (curLine.startsWith("data:")) {
curLine = curLine.substring(5);
}
if (curLine.trim() === "[DONE]") {
break;
}
let curJson = JSON.parse(curLine);
console.debug("DBUG:SC:PART:Json:", curJson);
this.append_response(this.response_extract_stream(curJson, apiEP));
}
elP.innerText = this.latestResponse;
elP.scrollIntoView(false);
if (done) {
break;
}
}
console.debug("DBUG:SC:PART:Full:", this.latestResponse);
return this.latestResponse;
}
/**
* Handle the oneshot response from server/ai-model
* @param {Response} resp
* @param {string} apiEP
*/
async handle_response_oneshot(resp, apiEP) {
let respBody = await resp.json();
console.debug(`DBUG:SimpleChat:SC:${this.chatId}:HandleUserSubmit:RespBody:${JSON.stringify(respBody)}`);
return this.response_extract(respBody, apiEP);
}
/**
* Handle the response from the server be it in oneshot or multipart/stream mode.
* Also take care of the optional garbage trimming.
* @param {Response} resp
* @param {string} apiEP
* @param {HTMLDivElement} elDiv
*/
async handle_response(resp, apiEP, elDiv) {
let theResp = {
assistant: "",
trimmed: "",
}
if (gMe.bStream) {
try {
theResp.assistant = await this.handle_response_multipart(resp, apiEP, elDiv);
this.latestResponse = "";
} catch (error) {
theResp.assistant = this.latestResponse;
this.add(Roles.Assistant, theResp.assistant);
this.latestResponse = "";
throw error;
}
} else {
theResp.assistant = await this.handle_response_oneshot(resp, apiEP);
}
if (gMe.bTrimGarbage) {
let origMsg = theResp.assistant;
theResp.assistant = du.trim_garbage_at_end(origMsg);
theResp.trimmed = origMsg.substring(theResp.assistant.length);
}
this.add(Roles.Assistant, theResp.assistant);
return theResp;
}
}
class MultiChatUI {
constructor() {
/** @type {Object<string, SimpleChat>} */
this.simpleChats = {};
/** @type {string} */
this.curChatId = "";
// the ui elements
this.elInSystem = /** @type{HTMLInputElement} */(document.getElementById("system-in"));
this.elDivChat = /** @type{HTMLDivElement} */(document.getElementById("chat-div"));
this.elBtnUser = /** @type{HTMLButtonElement} */(document.getElementById("user-btn"));
this.elInUser = /** @type{HTMLInputElement} */(document.getElementById("user-in"));
this.elDivHeading = /** @type{HTMLSelectElement} */(document.getElementById("heading"));
this.elDivSessions = /** @type{HTMLDivElement} */(document.getElementById("sessions-div"));
this.elBtnSettings = /** @type{HTMLButtonElement} */(document.getElementById("settings"));
this.validate_element(this.elInSystem, "system-in");
this.validate_element(this.elDivChat, "chat-div");
this.validate_element(this.elInUser, "user-in");
this.validate_element(this.elDivHeading, "heading");
this.validate_element(this.elDivChat, "sessions-div");
this.validate_element(this.elBtnSettings, "settings");
}
/**
* Check if the element was found.
* @param {HTMLElement | null} el
* @param {string} msgTag
*/
validate_element(el, msgTag) {
if (el == null) {
throw Error(`ERRR:SimpleChat:MCUI:${msgTag} element missing in html...`);
} else {
console.debug(`INFO:SimpleChat:MCUI:${msgTag} Id[${el.id}] Name[${el["name"]}]`);
}
}
/**
* Reset user input ui.
* * clear user input
* * enable user input
* * set focus to user input
*/
ui_reset_userinput() {
this.elInUser.value = "";
this.elInUser.disabled = false;
this.elInUser.focus();
}
/**
* Setup the needed UI callbacks, set curChatId to defaultChatId and
* optionally switch to the specified defaultChatId.
* @param {string} defaultChatId
* @param {boolean} bSwitchSession
*/
setup_ui(defaultChatId, bSwitchSession=false) {
this.curChatId = defaultChatId;
if (bSwitchSession) {
this.handle_session_switch(this.curChatId);
}
this.elBtnSettings.addEventListener("click", (ev)=>{
this.elDivChat.replaceChildren();
gMe.show_settings(this.elDivChat);
});
this.elBtnUser.addEventListener("click", (ev)=>{
if (this.elInUser.disabled) {
return;
}
this.handle_user_submit(this.curChatId, gMe.apiEP).catch((/** @type{Error} */reason)=>{
let msg = `ERRR:SimpleChat\nMCUI:HandleUserSubmit:${this.curChatId}\n${reason.name}:${reason.message}`;
console.error(msg.replace("\n", ":"));
alert(msg);
this.ui_reset_userinput();
});
});
this.elInUser.addEventListener("keyup", (ev)=> {
// allow user to insert enter into their message using shift+enter.
// while just pressing enter key will lead to submitting.
if ((ev.key === "Enter") && (!ev.shiftKey)) {
let value = this.elInUser.value;
this.elInUser.value = value.substring(0,value.length-1);
this.elBtnUser.click();
ev.preventDefault();
}
});
this.elInSystem.addEventListener("keyup", (ev)=> {
// allow user to insert enter into the system prompt using shift+enter.
// while just pressing enter key will lead to setting the system prompt.
if ((ev.key === "Enter") && (!ev.shiftKey)) {
let value = this.elInSystem.value;
this.elInSystem.value = value.substring(0,value.length-1);
let chat = this.simpleChats[this.curChatId];
chat.add_system_anytime(this.elInSystem.value, this.curChatId);
chat.show(this.elDivChat);
ev.preventDefault();
}
});
}
/**
* Setup a new chat session and optionally switch to it.
* @param {string} chatId
* @param {boolean} bSwitchSession
*/
new_chat_session(chatId, bSwitchSession=false) {
this.simpleChats[chatId] = new SimpleChat(chatId);
if (bSwitchSession) {
this.handle_session_switch(chatId);
}
}
/**
* Handle user query submit request, wrt specified chat session.
* @param {string} chatId
* @param {string} apiEP
*/
async handle_user_submit(chatId, apiEP) {
let chat = this.simpleChats[chatId];
// In completion mode, if configured, clear any previous chat history.
// So if user wants to simulate a multi-chat based completion query,
// they will have to enter the full thing, as a suitable multiline
// user input/query.
if ((apiEP == ApiEP.Type.Completion) && (gMe.bCompletionFreshChatAlways)) {
chat.clear();
}
chat.add_system_anytime(this.elInSystem.value, chatId);
let content = this.elInUser.value;
if (!chat.add(Roles.User, content)) {
console.debug(`WARN:SimpleChat:MCUI:${chatId}:HandleUserSubmit:Ignoring empty user input...`);
return;
}
chat.show(this.elDivChat);
let theUrl = ApiEP.Url(gMe.baseURL, apiEP);
let theBody = chat.request_jsonstr(apiEP);
this.elInUser.value = "working...";
this.elInUser.disabled = true;
console.debug(`DBUG:SimpleChat:MCUI:${chatId}:HandleUserSubmit:${theUrl}:ReqBody:${theBody}`);
let theHeaders = chat.fetch_headers(apiEP);
let resp = await fetch(theUrl, {
method: "POST",
headers: theHeaders,
body: theBody,
});
let theResp = await chat.handle_response(resp, apiEP, this.elDivChat);
if (chatId == this.curChatId) {
chat.show(this.elDivChat);
if (theResp.trimmed.length > 0) {
let p = ui.el_create_append_p(`TRIMMED:${theResp.trimmed}`, this.elDivChat);
p.className="role-trim";
}
} else {
console.debug(`DBUG:SimpleChat:MCUI:HandleUserSubmit:ChatId has changed:[${chatId}] [${this.curChatId}]`);
}
this.ui_reset_userinput();
}
/**
* Show buttons for NewChat and available chat sessions, in the passed elDiv.
* If elDiv is undefined/null, then use this.elDivSessions.
* Take care of highlighting the selected chat-session's btn.
* @param {HTMLDivElement | undefined} elDiv
*/
show_sessions(elDiv=undefined) {
if (!elDiv) {
elDiv = this.elDivSessions;
}
elDiv.replaceChildren();
// Btn for creating new chat session
let btnNew = ui.el_create_button("New CHAT", (ev)=> {
if (this.elInUser.disabled) {
console.error(`ERRR:SimpleChat:MCUI:NewChat:Current session [${this.curChatId}] awaiting response, ignoring request...`);
alert("ERRR:SimpleChat\nMCUI:NewChat\nWait for response to pending query, before starting new chat session");
return;
}
let chatId = `Chat${Object.keys(this.simpleChats).length}`;
let chatIdGot = prompt("INFO:SimpleChat\nMCUI:NewChat\nEnter id for new chat session", chatId);
if (!chatIdGot) {
console.error("ERRR:SimpleChat:MCUI:NewChat:Skipping based on user request...");
return;
}
this.new_chat_session(chatIdGot, true);
this.create_session_btn(elDiv, chatIdGot);
ui.el_children_config_class(elDiv, chatIdGot, "session-selected", "");
});
elDiv.appendChild(btnNew);
// Btns for existing chat sessions
let chatIds = Object.keys(this.simpleChats);
for(let cid of chatIds) {
let btn = this.create_session_btn(elDiv, cid);
if (cid == this.curChatId) {
btn.className = "session-selected";
}
}
}
create_session_btn(elDiv, cid) {
let btn = ui.el_create_button(cid, (ev)=>{
let target = /** @type{HTMLButtonElement} */(ev.target);
console.debug(`DBUG:SimpleChat:MCUI:SessionClick:${target.id}`);
if (this.elInUser.disabled) {
console.error(`ERRR:SimpleChat:MCUI:SessionClick:${target.id}:Current session [${this.curChatId}] awaiting response, ignoring switch...`);
alert("ERRR:SimpleChat\nMCUI:SessionClick\nWait for response to pending query, before switching");
return;
}
this.handle_session_switch(target.id);
ui.el_children_config_class(elDiv, target.id, "session-selected", "");
});
elDiv.appendChild(btn);
return btn;
}
/**
* Switch ui to the specified chatId and set curChatId to same.
* @param {string} chatId
*/
async handle_session_switch(chatId) {
let chat = this.simpleChats[chatId];
if (chat == undefined) {
console.error(`ERRR:SimpleChat:MCUI:HandleSessionSwitch:${chatId} missing...`);
return;
}
this.elInSystem.value = chat.get_system_latest();
this.elInUser.value = "";
chat.show(this.elDivChat);
this.elInUser.focus();
this.curChatId = chatId;
console.log(`INFO:SimpleChat:MCUI:HandleSessionSwitch:${chatId} entered...`);
}
}
class Me {
constructor() {
this.baseURL = "http://127.0.0.1:8080";
this.defaultChatIds = [ "Default", "Other" ];
this.multiChat = new MultiChatUI();
this.bStream = true;
this.bCompletionFreshChatAlways = true;
this.bCompletionInsertStandardRolePrefix = false;
this.bTrimGarbage = true;
this.iRecentUserMsgCnt = 2;
this.sRecentUserMsgCnt = {
"Full": -1,
"Last0": 1,
"Last1": 2,
"Last2": 3,
"Last4": 5,
};
this.apiEP = ApiEP.Type.Chat;
this.headers = {
"Content-Type": "application/json",
"Authorization": "", // Authorization: Bearer OPENAI_API_KEY
}
// Add the needed fields to the json object to be sent to the LLM web service's completions endpoint.
this.apiRequestOptions = {
"model": "gpt-3.5-turbo",
"temperature": 0.7,
"max_tokens": 1024,
"n_predict": 1024,
"cache_prompt": false,
//"frequency_penalty": 1.2,
//"presence_penalty": 1.2,
};
}
/**
* Disable console.debug by mapping it to an empty function.
*/
debug_disable() {
this.console_debug = console.debug;
console.debug = () => {
};
}
/**
* Setup the load saved chat ui.
* @param {HTMLDivElement} div
* @param {SimpleChat} chat
*/
setup_load(div, chat) {
if (!(chat.ods_key() in localStorage)) {
return;
}
div.innerHTML += `<p class="role-system">Restore</p>
<p>Load previously saved chat session, if available</p>`;
let btn = ui.el_create_button(chat.ods_key(), (ev)=>{
console.log("DBUG:SimpleChat:SC:Load", chat);
chat.load();
queueMicrotask(()=>{
chat.show(div);
this.multiChat.elInSystem.value = chat.get_system_latest();
});
});
div.appendChild(btn);
}
/**
* Show the configurable parameters info in the passed Div element.
* @param {HTMLDivElement} elDiv
* @param {boolean} bAll
*/
show_info(elDiv, bAll=false) {
let p = ui.el_create_append_p("Settings (devel-tools-console document[gMe])", elDiv);
p.className = "role-system";
if (bAll) {
ui.el_create_append_p(`baseURL:${this.baseURL}`, elDiv);
ui.el_create_append_p(`Authorization:${this.headers["Authorization"]}`, elDiv);
ui.el_create_append_p(`bStream:${this.bStream}`, elDiv);
ui.el_create_append_p(`bTrimGarbage:${this.bTrimGarbage}`, elDiv);
ui.el_create_append_p(`ApiEndPoint:${this.apiEP}`, elDiv);
ui.el_create_append_p(`iRecentUserMsgCnt:${this.iRecentUserMsgCnt}`, elDiv);
ui.el_create_append_p(`bCompletionFreshChatAlways:${this.bCompletionFreshChatAlways}`, elDiv);
ui.el_create_append_p(`bCompletionInsertStandardRolePrefix:${this.bCompletionInsertStandardRolePrefix}`, elDiv);
}
ui.el_create_append_p(`apiRequestOptions:${JSON.stringify(this.apiRequestOptions, null, " - ")}`, elDiv);
ui.el_create_append_p(`headers:${JSON.stringify(this.headers, null, " - ")}`, elDiv);
}
/**
* Auto create ui input elements for fields in apiRequestOptions
* Currently supports text and number field types.
* @param {HTMLDivElement} elDiv
*/
show_settings_apirequestoptions(elDiv) {
let typeDict = {
"string": "text",
"number": "number",
};
let fs = document.createElement("fieldset");
let legend = document.createElement("legend");
legend.innerText = "ApiRequestOptions";
fs.appendChild(legend);
elDiv.appendChild(fs);
for(const k in this.apiRequestOptions) {
let val = this.apiRequestOptions[k];
let type = typeof(val);
if (((type == "string") || (type == "number"))) {
let inp = ui.el_creatediv_input(`Set${k}`, k, typeDict[type], this.apiRequestOptions[k], (val)=>{
if (type == "number") {
val = Number(val);
}
this.apiRequestOptions[k] = val;
});
fs.appendChild(inp.div);
} else if (type == "boolean") {
let bbtn = ui.el_creatediv_boolbutton(`Set${k}`, k, {true: "true", false: "false"}, val, (userVal)=>{
this.apiRequestOptions[k] = userVal;
});
fs.appendChild(bbtn.div);
}
}
}
/**
* Show settings ui for configurable parameters, in the passed Div element.
* @param {HTMLDivElement} elDiv
*/
show_settings(elDiv) {
let inp = ui.el_creatediv_input("SetBaseURL", "BaseURL", "text", this.baseURL, (val)=>{
this.baseURL = val;
});
elDiv.appendChild(inp.div);
inp = ui.el_creatediv_input("SetAuthorization", "Authorization", "text", this.headers["Authorization"], (val)=>{
this.headers["Authorization"] = val;
});
inp.el.placeholder = "Bearer OPENAI_API_KEY";
elDiv.appendChild(inp.div);
let bb = ui.el_creatediv_boolbutton("SetStream", "Stream", {true: "[+] yes stream", false: "[-] do oneshot"}, this.bStream, (val)=>{
this.bStream = val;
});
elDiv.appendChild(bb.div);
bb = ui.el_creatediv_boolbutton("SetTrimGarbage", "TrimGarbage", {true: "[+] yes trim", false: "[-] dont trim"}, this.bTrimGarbage, (val)=>{
this.bTrimGarbage = val;
});
elDiv.appendChild(bb.div);
this.show_settings_apirequestoptions(elDiv);
let sel = ui.el_creatediv_select("SetApiEP", "ApiEndPoint", ApiEP.Type, this.apiEP, (val)=>{
this.apiEP = ApiEP.Type[val];
});
elDiv.appendChild(sel.div);
sel = ui.el_creatediv_select("SetChatHistoryInCtxt", "ChatHistoryInCtxt", this.sRecentUserMsgCnt, this.iRecentUserMsgCnt, (val)=>{
this.iRecentUserMsgCnt = this.sRecentUserMsgCnt[val];
});
elDiv.appendChild(sel.div);
bb = ui.el_creatediv_boolbutton("SetCompletionFreshChatAlways", "CompletionFreshChatAlways", {true: "[+] yes fresh", false: "[-] no, with history"}, this.bCompletionFreshChatAlways, (val)=>{
this.bCompletionFreshChatAlways = val;
});
elDiv.appendChild(bb.div);
bb = ui.el_creatediv_boolbutton("SetCompletionInsertStandardRolePrefix", "CompletionInsertStandardRolePrefix", {true: "[+] yes insert", false: "[-] dont insert"}, this.bCompletionInsertStandardRolePrefix, (val)=>{
this.bCompletionInsertStandardRolePrefix = val;
});
elDiv.appendChild(bb.div);
}
}
/** @type {Me} */
let gMe;
function startme() {
console.log("INFO:SimpleChat:StartMe:Starting...");
gMe = new Me();
gMe.debug_disable();
document["gMe"] = gMe;
document["du"] = du;
for (let cid of gMe.defaultChatIds) {
gMe.multiChat.new_chat_session(cid);
}
gMe.multiChat.setup_ui(gMe.defaultChatIds[0], true);
gMe.multiChat.show_sessions();
}
document.addEventListener("DOMContentLoaded", startme);

Binary file not shown (21 KiB)

View file

@ -0,0 +1,211 @@
//@ts-check
// Helpers to work with html elements
// by Humans for All
//
/**
* Set the class of each child, based on whether its id matches idSelected or not.
* @param {HTMLDivElement} elBase
* @param {string} idSelected
* @param {string} classSelected
* @param {string} classUnSelected
*/
export function el_children_config_class(elBase, idSelected, classSelected, classUnSelected="") {
for(let child of elBase.children) {
if (child.id == idSelected) {
child.className = classSelected;
} else {
child.className = classUnSelected;
}
}
}
/**
* Create button and set it up.
* @param {string} id
* @param {(this: HTMLButtonElement, ev: MouseEvent) => any} callback
* @param {string | undefined} name
* @param {string | undefined} innerText
*/
export function el_create_button(id, callback, name=undefined, innerText=undefined) {
if (!name) {
name = id;
}
if (!innerText) {
innerText = id;
}
let btn = document.createElement("button");
btn.id = id;
btn.name = name;
btn.innerText = innerText;
btn.addEventListener("click", callback);
return btn;
}
/**
* Create a para and set it up. Optionally append it to a passed parent.
* @param {string} text
* @param {HTMLElement | undefined} elParent
* @param {string | undefined} id
*/
export function el_create_append_p(text, elParent=undefined, id=undefined) {
let para = document.createElement("p");
para.innerText = text;
if (id) {
para.id = id;
}
if (elParent) {
elParent.appendChild(para);
}
return para;
}
/**
* Create a button which represents a bool value using the specified texts for true and false.
* Whenever the user clicks the button, it will toggle the value and update the shown text.
*
* @param {string} id
* @param {{true: string, false: string}} texts
* @param {boolean} defaultValue
* @param {function(boolean):void} cb
*/
export function el_create_boolbutton(id, texts, defaultValue, cb) {
let el = document.createElement("button");
el["xbool"] = defaultValue;
el["xtexts"] = structuredClone(texts);
el.innerText = el["xtexts"][String(defaultValue)];
if (id) {
el.id = id;
}
el.addEventListener('click', (ev)=>{
el["xbool"] = !el["xbool"];
el.innerText = el["xtexts"][String(el["xbool"])];
cb(el["xbool"]);
})
return el;
}
/**
* Create a div wrapped button which represents bool value using specified text wrt true and false.
* @param {string} id
* @param {string} label
* @param {{ true: string; false: string; }} texts
* @param {boolean} defaultValue
* @param {(arg0: boolean) => void} cb
* @param {string} className
*/
export function el_creatediv_boolbutton(id, label, texts, defaultValue, cb, className="gridx2") {
let div = document.createElement("div");
div.className = className;
let lbl = document.createElement("label");
lbl.setAttribute("for", id);
lbl.innerText = label;
div.appendChild(lbl);
let btn = el_create_boolbutton(id, texts, defaultValue, cb);
div.appendChild(btn);
return { div: div, el: btn };
}
/**
* Create a select ui element, with a set of options to select from.
* * options: an object which contains name-value pairs
* defaultOption: the value whose name should be chosen, by default.
* cb : the callback returns the name string of the option selected.
*
* @param {string} id
* @param {Object<string,*>} options
* @param {*} defaultOption
* @param {function(string):void} cb
*/
export function el_create_select(id, options, defaultOption, cb) {
let el = document.createElement("select");
el["xselected"] = defaultOption;
el["xoptions"] = structuredClone(options);
for(let cur of Object.keys(options)) {
let op = document.createElement("option");
op.value = cur;
op.innerText = cur;
if (options[cur] == defaultOption) {
op.selected = true;
}
el.appendChild(op);
}
if (id) {
el.id = id;
el.name = id;
}
el.addEventListener('change', (ev)=>{
let target = /** @type{HTMLSelectElement} */(ev.target);
console.log("DBUG:UI:Select:", id, ":", target.value);
cb(target.value);
})
return el;
}
/**
* Create a div wrapped select ui element, with a set of options to select from.
*
* @param {string} id
* @param {any} label
* @param {{ [x: string]: any; }} options
* @param {any} defaultOption
* @param {(arg0: string) => void} cb
* @param {string} className
*/
export function el_creatediv_select(id, label, options, defaultOption, cb, className="gridx2") {
let div = document.createElement("div");
div.className = className;
let lbl = document.createElement("label");
lbl.setAttribute("for", id);
lbl.innerText = label;
div.appendChild(lbl);
let sel = el_create_select(id, options, defaultOption, cb);
div.appendChild(sel);
return { div: div, el: sel };
}
/**
* Create an input ui element.
*
* @param {string} id
* @param {string} type
* @param {any} defaultValue
* @param {function(any):void} cb
*/
export function el_create_input(id, type, defaultValue, cb) {
let el = document.createElement("input");
el.type = type;
el.value = defaultValue;
if (id) {
el.id = id;
}
el.addEventListener('change', (ev)=>{
cb(el.value);
})
return el;
}
/**
* Create a div wrapped input.
*
* @param {string} id
* @param {string} label
* @param {string} type
* @param {any} defaultValue
* @param {function(any):void} cb
* @param {string} className
*/
export function el_creatediv_input(id, label, type, defaultValue, cb, className="gridx2") {
let div = document.createElement("div");
div.className = className;
let lbl = document.createElement("label");
lbl.setAttribute("for", id);
lbl.innerText = label;
div.appendChild(lbl);
let el = el_create_input(id, type, defaultValue, cb);
div.appendChild(el);
return { div: div, el: el };
}

4640
tools/server/server.cpp Normal file

File diff suppressed because it is too large Load diff

2
tools/server/tests/.gitignore vendored Normal file
View file

@ -0,0 +1,2 @@
.venv
tmp

View file

@ -0,0 +1,66 @@
# Server tests
Python based server tests scenario using [pytest](https://docs.pytest.org/en/stable/).
Tests target GitHub workflows job runners with 4 vCPU.
Note: If inference on the host architecture is faster than on the GitHub runners, the parallel scenario may randomly fail.
To mitigate this, you can increase the `n_predict` and `kv_size` values.
### Install dependencies
`pip install -r requirements.txt`
### Run tests
1. Build the server
```shell
cd ../../..
cmake -B build
cmake --build build --target llama-server
```
2. Start the test: `./tests.sh`
It's possible to override some scenario step values with environment variables:
| variable | description |
|--------------------------|------------------------------------------------------------------------------------------------|
| `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` |
| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/llama-server` |
| `DEBUG`                  | to enable verbose mode for the test steps and the server (`--verbose`)                          |
| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` |
| `LLAMA_CACHE` | by default server tests re-download models to the `tmp` subfolder. Set this to your cache (e.g. `$HOME/Library/Caches/llama.cpp` on Mac or `$HOME/.cache/llama.cpp` on Unix) to avoid this |
To run slow tests (will download many models, make sure to set `LLAMA_CACHE` if needed):
```shell
SLOW_TESTS=1 ./tests.sh
```
To run with stdout/stderr display in real time (verbose output, but useful for debugging):
```shell
DEBUG=1 ./tests.sh -s -v -x
```
To run all the tests in a file:
```shell
./tests.sh unit/test_chat_completion.py -v -x
```
To run a single test:
```shell
./tests.sh unit/test_chat_completion.py::test_invalid_chat_completion_req
```
Hint: You can compile and run the tests in a single command, which is useful for local development:
```shell
cmake --build build -j --target llama-server && ./tools/server/tests/tests.sh
```
To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)

View file

@ -0,0 +1,15 @@
import pytest
from utils import *
# ref: https://stackoverflow.com/questions/22627659/run-code-before-and-after-each-test-in-py-test
@pytest.fixture(autouse=True)
def stop_server_after_each_test():
# do nothing before each test
yield
# stop all servers after each test
instances = set(
server_instances
) # copy the set to prevent 'Set changed size during iteration'
for server in instances:
server.stop()

View file

@ -0,0 +1,4 @@
[pytest]
markers =
slow: marks tests as slow (deselect with '-m "not slow"')
serial

View file

@ -0,0 +1,8 @@
aiohttp~=3.9.3
pytest~=8.3.3
huggingface_hub~=0.23.2
numpy~=1.26.4
openai~=1.55.3
prometheus-client~=0.20.0
requests~=2.32.3
wget~=3.2

23
tools/server/tests/tests.sh Executable file
View file

@ -0,0 +1,23 @@
#!/bin/bash
# make sure we are in the right directory
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd $SCRIPT_DIR
set -eu
if [[ "${SLOW_TESTS:-0}" == 1 ]]; then
# Slow tests for tool calls need quite a few models ahead of time to avoid timing out.
python $SCRIPT_DIR/../../../scripts/fetch_server_test_models.py
fi
if [ $# -lt 1 ]
then
if [[ "${SLOW_TESTS:-0}" == 1 ]]; then
pytest -v -x
else
pytest -v -x -m "not slow"
fi
else
pytest "$@"
fi

View file

@ -0,0 +1,96 @@
import pytest
import requests
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_server_start_simple():
global server
server.start()
res = server.make_request("GET", "/health")
assert res.status_code == 200
def test_server_props():
global server
server.start()
res = server.make_request("GET", "/props")
assert res.status_code == 200
assert ".gguf" in res.body["model_path"]
assert res.body["total_slots"] == server.n_slots
default_val = res.body["default_generation_settings"]
assert server.n_ctx is not None and server.n_slots is not None
assert default_val["n_ctx"] == server.n_ctx / server.n_slots
assert default_val["params"]["seed"] == server.seed
def test_server_models():
global server
server.start()
res = server.make_request("GET", "/models")
assert res.status_code == 200
assert len(res.body["data"]) == 1
assert res.body["data"][0]["id"] == server.model_alias
def test_server_slots():
global server
# without slots endpoint enabled, this should return error
server.server_slots = False
server.start()
res = server.make_request("GET", "/slots")
assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
assert "error" in res.body
server.stop()
# with slots endpoint enabled, this should return slots info
server.server_slots = True
server.n_slots = 2
server.start()
res = server.make_request("GET", "/slots")
assert res.status_code == 200
assert len(res.body) == server.n_slots
assert server.n_ctx is not None and server.n_slots is not None
assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
assert "params" in res.body[0]
assert res.body[0]["params"]["seed"] == server.seed
def test_load_split_model():
global server
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf"
server.model_alias = "tinyllama-split"
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 16,
"prompt": "Hello",
"temperature": 0.0,
})
assert res.status_code == 200
assert match_regex("(little|girl)+", res.body["content"])
def test_no_webui():
global server
# default: webui enabled
server.start()
url = f"http://{server.server_host}:{server.server_port}"
res = requests.get(url)
assert res.status_code == 200
assert "<html>" in res.text
server.stop()
# with --no-webui
server.no_webui = True
server.start()
res = requests.get(url)
assert res.status_code == 404

View file

@ -0,0 +1,311 @@
import pytest
from openai import OpenAI
from utils import *
server: ServerProcess
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
[
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
(None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
(None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
(None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
(None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
global server
server.jinja = jinja
server.chat_template = chat_template
server.start()
res = server.make_request("POST", "/chat/completions", data={
"model": model,
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
})
assert res.status_code == 200
assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
assert res.body["system_fingerprint"].startswith("b")
assert res.body["model"] == model if model is not None else server.model_alias
assert res.body["usage"]["prompt_tokens"] == n_prompt
assert res.body["usage"]["completion_tokens"] == n_predicted
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
assert choice["finish_reason"] == finish_reason
@pytest.mark.parametrize(
"system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
[
("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
global server
server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL
server.start()
res = server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": max_tokens,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"stream": True,
})
content = ""
last_cmpl_id = None
for data in res:
choice = data["choices"][0]
assert data["system_fingerprint"].startswith("b")
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
if last_cmpl_id is None:
last_cmpl_id = data["id"]
assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
if choice["finish_reason"] in ["stop", "length"]:
assert data["usage"]["prompt_tokens"] == n_prompt
assert data["usage"]["completion_tokens"] == n_predicted
assert "content" not in choice["delta"]
assert match_regex(re_content, content)
assert choice["finish_reason"] == finish_reason
else:
assert choice["finish_reason"] is None
content += choice["delta"]["content"]
def test_chat_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=8,
seed=42,
temperature=0.8,
)
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
assert res.choices[0].finish_reason == "length"
assert res.choices[0].message.content is not None
assert match_regex("(Suddenly)+", res.choices[0].message.content)
def test_chat_template():
global server
server.chat_template = "llama3"
server.debug = True # to get the "__verbose" object in the response
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": 8,
"messages": [
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
]
})
assert res.status_code == 200
assert "__verbose" in res.body
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
def test_apply_chat_template():
global server
server.chat_template = "command-r"
server.start()
res = server.make_request("POST", "/apply-template", data={
"messages": [
{"role": "system", "content": "You are a test."},
{"role": "user", "content":"Hi there"},
]
})
assert res.status_code == 200
assert "prompt" in res.body
assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""),
({"type": "json_object"}, 10, "(\\{|John)+"),
({"type": "sound"}, 0, None),
# invalid response format (expected to fail)
({"type": "json_object", "schema": 123}, 0, None),
({"type": "json_object", "schema": {"type": 123}}, 0, None),
({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None),
])
def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"response_format": response_format,
})
if re_content is not None:
assert res.status_code == 200
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"])
else:
assert res.status_code != 200
assert "error" in res.body
@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [
(False, {"const": "42"}, 6, "\"42\""),
(True, {"const": "42"}, 6, "\"42\""),
])
def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"json_schema": json_schema,
})
assert res.status_code == 200, f'Expected 200, got {res.status_code}'
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}'
@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [
(False, 'root ::= "a"{5,5}', 6, "a{5,5}"),
(True, 'root ::= "a"{5,5}', 6, "a{5,5}"),
])
def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str):
global server
server.jinja = jinja
server.start()
res = server.make_request("POST", "/chat/completions", data={
"max_tokens": n_predicted,
"messages": [
{"role": "user", "content": "Does not matter what I say, does it?"},
],
"grammar": grammar,
})
assert res.status_code == 200, res.body
choice = res.body["choices"][0]
assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"]
@pytest.mark.parametrize("messages", [
None,
"string",
[123],
[{}],
[{"role": 123}],
[{"role": "system", "content": 123}],
# [{"content": "hello"}], # TODO: should not be a valid case
[{"role": "system", "content": "test"}, {}],
])
def test_invalid_chat_completion_req(messages):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
"messages": messages,
})
assert res.status_code == 400 or res.status_code == 500
assert "error" in res.body
def test_chat_completion_with_timings_per_token():
global server
server.start()
res = server.make_stream_request("POST", "/chat/completions", data={
"max_tokens": 10,
"messages": [{"role": "user", "content": "test"}],
"stream": True,
"timings_per_token": True,
})
for data in res:
assert "timings" in data
assert "prompt_per_second" in data["timings"]
assert "predicted_per_second" in data["timings"]
assert "predicted_n" in data["timings"]
assert data["timings"]["predicted_n"] <= 10
def test_logprobs():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=5,
logprobs=True,
top_logprobs=10,
)
output_text = res.choices[0].message.content
aggregated_text = ''
assert res.choices[0].logprobs is not None
assert res.choices[0].logprobs.content is not None
for token in res.choices[0].logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text
def test_logprobs_stream():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.chat.completions.create(
model="gpt-3.5-turbo-instruct",
temperature=0.0,
messages=[
{"role": "system", "content": "Book"},
{"role": "user", "content": "What is the best book"},
],
max_tokens=5,
logprobs=True,
top_logprobs=10,
stream=True,
)
output_text = ''
aggregated_text = ''
for data in res:
choice = data.choices[0]
if choice.finish_reason is None:
if choice.delta.content:
output_text += choice.delta.content
assert choice.logprobs is not None
assert choice.logprobs.content is not None
for token in choice.logprobs.content:
aggregated_text += token.token
assert token.logprob <= 0.0
assert token.bytes is not None
assert token.top_logprobs is not None
assert len(token.top_logprobs) > 0
assert aggregated_text == output_text

View file

@ -0,0 +1,428 @@
import pytest
import requests
import time
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": prompt,
"return_tokens": return_tokens,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == n_prompt
assert res.body["timings"]["predicted_n"] == n_predicted
assert res.body["truncated"] == truncated
assert type(res.body["has_new_line"]) == bool
assert match_regex(re_content, res.body["content"])
if return_tokens:
assert len(res.body["tokens"]) > 0
assert all(type(tok) == int for tok in res.body["tokens"])
else:
assert res.body["tokens"] == []
@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
global server
server.start()
res = server.make_stream_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": prompt,
"stream": True,
})
content = ""
for data in res:
assert "stop" in data and type(data["stop"]) == bool
if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated
assert data["stop_type"] == "limit"
assert type(data["has_new_line"]) == bool
assert "generation_settings" in data
assert server.n_predict is not None
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
assert data["generation_settings"]["seed"] == server.seed
assert match_regex(re_content, content)
else:
assert len(data["tokens"]) > 0
assert all(type(tok) == int for tok in data["tokens"])
content += data["content"]
def test_completion_stream_vs_non_stream():
global server
server.start()
res_stream = server.make_stream_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "I believe the meaning of life is",
"stream": True,
})
res_non_stream = server.make_request("POST", "/completion", data={
"n_predict": 8,
"prompt": "I believe the meaning of life is",
})
content_stream = ""
for data in res_stream:
content_stream += data["content"]
assert content_stream == res_non_stream.body["content"]
def test_completion_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="I believe the meaning of life is",
max_tokens=8,
)
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
assert res.choices[0].finish_reason == "length"
assert res.choices[0].text is not None
assert match_regex("(going|bed)+", res.choices[0].text)
def test_completion_stream_with_openai_library():
global server
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.completions.create(
model="davinci-002",
prompt="I believe the meaning of life is",
max_tokens=8,
stream=True,
)
output_text = ''
for data in res:
choice = data.choices[0]
if choice.finish_reason is None:
assert choice.text is not None
output_text += choice.text
assert match_regex("(going|bed)+", output_text)
@pytest.mark.parametrize("n_slots", [1, 2])
def test_consistent_result_same_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 0.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.parametrize("n_slots", [1, 2])
def test_different_result_different_seed(n_slots: int):
global server
server.n_slots = n_slots
server.start()
last_res = None
for seed in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": seed,
"temperature": 1.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] != last_res.body["content"]
last_res = res
# TODO: figure out why it doesn't work with temperature = 1
# @pytest.mark.parametrize("temperature", [0.0, 1.0])
@pytest.mark.parametrize("n_batch", [16, 32])
@pytest.mark.parametrize("temperature", [0.0])
def test_consistent_result_different_batch_size(n_batch: int, temperature: float):
global server
server.n_batch = n_batch
server.start()
last_res = None
for _ in range(4):
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": temperature,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
if last_res is not None:
assert res.body["content"] == last_res.body["content"]
last_res = res
@pytest.mark.skip(reason="This test fails on linux, need to be fixed")
def test_cache_vs_nocache_prompt():
global server
server.start()
res_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": True,
})
res_no_cache = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"seed": 42,
"temperature": 1.0,
"cache_prompt": False,
})
assert res_cache.body["content"] == res_no_cache.body["content"]
def test_completion_with_tokens_input():
global server
server.temperature = 0.0
server.start()
prompt_str = "I believe the meaning of life is"
res = server.make_request("POST", "/tokenize", data={
"content": prompt_str,
"add_special": True,
})
assert res.status_code == 200
tokens = res.body["tokens"]
# single completion
res = server.make_request("POST", "/completion", data={
"prompt": tokens,
})
assert res.status_code == 200
assert type(res.body["content"]) == str
# batch completion
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, tokens],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens
res = server.make_request("POST", "/completion", data={
"prompt": [tokens, prompt_str],
})
assert res.status_code == 200
assert type(res.body) == list
assert len(res.body) == 2
assert res.body[0]["content"] == res.body[1]["content"]
# mixed string and tokens in one sequence
res = server.make_request("POST", "/completion", data={
"prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str],
})
assert res.status_code == 200
assert type(res.body["content"]) == str
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 3),
(2, 2),
(2, 4),
(4, 2), # some slots must be idle
(4, 6),
])
def test_completion_parallel_slots(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.temperature = 0.0
server.start()
PROMPTS = [
("Write a very long book.", "(very|special|big)+"),
("Write another a poem.", "(small|house)+"),
("What is LLM?", "(Dad|said)+"),
("The sky is blue and I love it.", "(climb|leaf)+"),
("Write another very long music lyrics.", "(friends|step|sky)+"),
("Write a very long joke.", "(cat|Whiskers)+"),
]
def check_slots_status():
should_all_slots_busy = n_requests >= n_slots
time.sleep(0.1)
res = server.make_request("GET", "/slots")
n_busy = sum([1 for slot in res.body if slot["is_processing"]])
if should_all_slots_busy:
assert n_busy == n_slots
else:
assert n_busy <= n_slots
tasks = []
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
tasks.append((server.make_request, ("POST", "/completion", {
"prompt": prompt,
"seed": 42,
"temperature": 1.0,
})))
tasks.append((check_slots_status, ()))
results = parallel_function_calls(tasks)
# check results
for i in range(n_requests):
prompt, re_content = PROMPTS[i % len(PROMPTS)]
res = results[i]
assert res.status_code == 200
assert type(res.body["content"]) == str
assert len(res.body["content"]) > 10
# FIXME: the result is not deterministic when using a slot other than slot 0
# assert match_regex(re_content, res.body["content"])
@pytest.mark.parametrize(
"prompt,n_predict,response_fields",
[
("I believe the meaning of life is", 8, []),
("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]),
],
)
def test_completion_response_fields(
prompt: str, n_predict: int, response_fields: list[str]
):
global server
server.start()
res = server.make_request(
"POST",
"/completion",
data={
"n_predict": n_predict,
"prompt": prompt,
"response_fields": response_fields,
},
)
assert res.status_code == 200
assert "content" in res.body
assert len(res.body["content"])
if len(response_fields):
assert res.body["generation_settings/n_predict"] == n_predict
assert res.body["prompt"] == "<s> " + prompt
assert isinstance(res.body["content"], str)
assert len(res.body) == len(response_fields)
else:
assert len(res.body)
assert "generation_settings" in res.body
def test_n_probs():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"n_probs": 10,
"temperature": 0.0,
"n_predict": 5,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "logprob" in tok and tok["logprob"] <= 0.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_logprobs"]) == 10
for prob in tok["top_logprobs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "logprob" in prob and prob["logprob"] <= 0.0
assert "bytes" in prob and type(prob["bytes"]) == list
def test_n_probs_stream():
global server
server.start()
res = server.make_stream_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"n_probs": 10,
"temperature": 0.0,
"n_predict": 5,
"stream": True,
})
for data in res:
if data["stop"] == False:
assert "completion_probabilities" in data
assert len(data["completion_probabilities"]) == 1
for tok in data["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "logprob" in tok and tok["logprob"] <= 0.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_logprobs"]) == 10
for prob in tok["top_logprobs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "logprob" in prob and prob["logprob"] <= 0.0
assert "bytes" in prob and type(prob["bytes"]) == list
def test_n_probs_post_sampling():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"n_probs": 10,
"temperature": 0.0,
"n_predict": 5,
"post_sampling_probs": True,
})
assert res.status_code == 200
assert "completion_probabilities" in res.body
assert len(res.body["completion_probabilities"]) == 5
for tok in res.body["completion_probabilities"]:
assert "id" in tok and tok["id"] > 0
assert "token" in tok and type(tok["token"]) == str
assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
assert "bytes" in tok and type(tok["bytes"]) == list
assert len(tok["top_probs"]) == 10
for prob in tok["top_probs"]:
assert "id" in prob and prob["id"] > 0
assert "token" in prob and type(prob["token"]) == str
assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
assert "bytes" in prob and type(prob["bytes"]) == list
# because the test model usually outputs tokens with either 100% or 0% probability, we need to check all the top_probs
assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
def test_cancel_request():
global server
server.n_ctx = 4096
server.n_predict = -1
server.n_slots = 1
server.server_slots = True
server.start()
# send a request that will take a long time, but cancel it before it finishes
try:
server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
}, timeout=0.1)
except requests.exceptions.ReadTimeout:
pass # expected
# make sure the slot is free
time.sleep(1) # wait for HTTP_POLLING_SECONDS
res = server.make_request("GET", "/slots")
assert res.body[0]["is_processing"] == False

View file

@ -0,0 +1,67 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
LONG_TEXT = """
Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.
Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur.
Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.
""".strip()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.n_ctx = 256
server.n_slots = 2
def test_ctx_shift_enabled():
# the prompt is 301 tokens
# the slot context is 256/2 = 128 tokens
# the prompt is truncated to keep the last 109 tokens
# 64 tokens are generated thanks to shifting the context when it gets full
global server
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
})
assert res.status_code == 200
assert res.body["timings"]["prompt_n"] == 109
assert res.body["timings"]["predicted_n"] == 64
assert res.body["truncated"] is True
@pytest.mark.parametrize("n_predict,n_token_output,truncated", [
(64, 64, False),
(-1, 120, True),
])
def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool):
global server
server.disable_ctx_shift = True
server.n_predict = -1
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": n_predict,
"prompt": "Hi how are you",
})
assert res.status_code == 200
assert res.body["timings"]["predicted_n"] == n_token_output
assert res.body["truncated"] == truncated
def test_ctx_shift_disabled_long_prompt():
global server
server.disable_ctx_shift = True
server.start()
res = server.make_request("POST", "/completion", data={
"n_predict": 64,
"prompt": LONG_TEXT,
})
assert res.status_code != 200
assert "error" in res.body
assert "exceeds the available context size" in res.body["error"]["message"]

View file

@ -0,0 +1,257 @@
import base64
import struct
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.bert_bge_small()
EPSILON = 1e-3
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.bert_bge_small()
def test_embedding_single():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "I believe the meaning of life is",
})
assert res.status_code == 200
assert len(res.body['data']) == 1
assert 'embedding' in res.body['data'][0]
assert len(res.body['data'][0]['embedding']) > 1
# make sure embedding vector is normalized
assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON
def test_embedding_multiple():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 4
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
def test_embedding_multiple_with_fa():
server = ServerPreset.bert_bge_small_with_fa()
server.pooling = 'last'
server.start()
# one of these should trigger the FA branch (i.e. context size % 256 == 0)
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"a "*253,
"b "*254,
"c "*255,
"d "*256,
],
})
assert res.status_code == 200
assert len(res.body['data']) == 4
for d in res.body['data']:
assert 'embedding' in d
assert len(d['embedding']) > 1
@pytest.mark.parametrize(
"input,is_multi_prompt",
[
# do not crash on empty input
("", False),
# single prompt
("string", False),
([12, 34, 56], False),
([12, 34, "string", 56, 78], False),
# multiple prompts
(["string1", "string2"], True),
(["string1", [12, 34, 56]], True),
([[12, 34, 56], [12, 34, 56]], True),
([[12, 34, 56], [12, "string", 34, 56]], True),
]
)
def test_embedding_mixed_input(input, is_multi_prompt: bool):
global server
server.start()
res = server.make_request("POST", "/v1/embeddings", data={"input": input})
assert res.status_code == 200
data = res.body['data']
if is_multi_prompt:
assert len(data) == len(input)
for d in data:
assert 'embedding' in d
assert len(d['embedding']) > 1
else:
assert 'embedding' in data[0]
assert len(data[0]['embedding']) > 1
def test_embedding_pooling_none():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/embeddings", data={
"input": "hello hello hello",
})
assert res.status_code == 200
assert 'embedding' in res.body[0]
assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
# make sure embedding vector is not normalized
for x in res.body[0]['embedding']:
assert abs(sum([v ** 2 for v in x]) - 1) > EPSILON
def test_embedding_pooling_none_oai():
global server
server.pooling = 'none'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "hello hello hello",
})
# /v1/embeddings does not support pooling type 'none'
assert res.status_code == 400
assert "error" in res.body
def test_embedding_openai_library_single():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is")
assert len(res.data) == 1
assert len(res.data[0].embedding) > 1
def test_embedding_openai_library_multiple():
global server
server.pooling = 'last'
server.start()
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
res = client.embeddings.create(model="text-embedding-3-small", input=[
"I believe the meaning of life is",
"Write a joke about AI from a very long prompt which will not be truncated",
"This is a test",
"This is another test",
])
assert len(res.data) == 4
for d in res.data:
assert len(d.embedding) > 1
def test_embedding_error_prompt_too_long():
global server
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": "This is a test " * 512,
})
assert res.status_code != 200
assert "too large" in res.body["error"]["message"]
def test_same_prompt_give_same_result():
server.pooling = 'last'
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
"I believe the meaning of life is",
],
})
assert res.status_code == 200
assert len(res.body['data']) == 5
for i in range(1, len(res.body['data'])):
v0 = res.body['data'][0]['embedding']
vi = res.body['data'][i]['embedding']
for x, y in zip(v0, vi):
assert abs(x - y) < EPSILON
@pytest.mark.parametrize(
"content,n_tokens",
[
("I believe the meaning of life is", 9),
("This is a test", 6),
]
)
def test_embedding_usage_single(content, n_tokens):
global server
server.start()
res = server.make_request("POST", "/v1/embeddings", data={"input": content})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
def test_embedding_usage_multiple():
global server
server.start()
res = server.make_request("POST", "/v1/embeddings", data={
"input": [
"I believe the meaning of life is",
"I believe the meaning of life is",
],
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == 2 * 9
def test_embedding_openai_library_base64():
server.start()
test_input = "Test base64 embedding output"
# get embedding in default format
res = server.make_request("POST", "/v1/embeddings", data={
"input": test_input
})
assert res.status_code == 200
vec0 = res.body["data"][0]["embedding"]
# get embedding in base64 format
res = server.make_request("POST", "/v1/embeddings", data={
"input": test_input,
"encoding_format": "base64"
})
assert res.status_code == 200
assert "data" in res.body
assert len(res.body["data"]) == 1
embedding_data = res.body["data"][0]
assert "embedding" in embedding_data
assert isinstance(embedding_data["embedding"], str)
# Verify embedding is valid base64
decoded = base64.b64decode(embedding_data["embedding"])
# Verify decoded data can be converted back to float array
float_count = len(decoded) // 4 # 4 bytes per float
floats = struct.unpack(f'{float_count}f', decoded)
assert len(floats) > 0
assert all(isinstance(x, float) for x in floats)
assert len(floats) == len(vec0)
# make sure the decoded data is the same as the original
for x, y in zip(floats, vec0):
assert abs(x - y) < EPSILON
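Several tests above compare embeddings element-wise; a client would more often compare them by cosine similarity. A small hedged sketch, assuming a local llama-server serving an embedding model at the default address:

```python
import math
import requests

BASE_URL = "http://localhost:8080"  # assumed local llama-server serving an embedding model

def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

res = requests.post(f"{BASE_URL}/v1/embeddings", json={
    "input": ["I believe the meaning of life is", "This is a test"],
}).json()
v0, v1 = (d["embedding"] for d in res["data"])
print(cosine_similarity(v0, v1))
```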

View file

@ -0,0 +1,77 @@
import pytest
from utils import *
server = ServerPreset.tinyllama_infill()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama_infill()
def test_infill_without_input_extra():
global server
server.start()
res = server.make_request("POST", "/infill", data={
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"])
def test_infill_with_input_extra():
global server
server.start()
res = server.make_request("POST", "/infill", data={
"input_extra": [{
"filename": "llama.h",
"text": "LLAMA_API int32_t llama_n_threads();\n"
}],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert match_regex("(Dad|excited|park)+", res.body["content"])
@pytest.mark.parametrize("input_extra", [
{},
{"filename": "ok"},
{"filename": 123},
{"filename": 123, "text": "abc"},
{"filename": 123, "text": 456},
])
def test_invalid_input_extra_req(input_extra):
global server
server.start()
res = server.make_request("POST", "/infill", data={
"input_extra": [input_extra],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 400
assert "error" in res.body
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_qwen_model():
global server
server.model_file = None
server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
server.start(timeout_seconds=600)
res = server.make_request("POST", "/infill", data={
"input_extra": [{
"filename": "llama.h",
"text": "LLAMA_API int32_t llama_n_threads();\n"
}],
"input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
"prompt": " int n_threads = llama_",
"input_suffix": "}\n",
})
assert res.status_code == 200
assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n"

View file

@ -0,0 +1,115 @@
import pytest
from utils import *
server = ServerPreset.stories15m_moe()
LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.stories15m_moe()
server.lora_files = [download_file(LORA_FILE_URL)]
@pytest.mark.parametrize("scale,re_content", [
# without applying lora, the model should behave like a bedtime story generator
(0.0, "(little|girl|three|years|old)+"),
# with lora, the model should behave like a Shakespearean text generator
(1.0, "(eye|love|glass|sun)+"),
])
def test_lora(scale: float, re_content: str):
global server
server.start()
res_lora_control = server.make_request("POST", "/lora-adapters", data=[
{"id": 0, "scale": scale}
])
assert res_lora_control.status_code == 200
res = server.make_request("POST", "/completion", data={
"prompt": "Look in thy glass",
})
assert res.status_code == 200
assert match_regex(re_content, res.body["content"])
def test_lora_per_request():
global server
server.n_slots = 4
server.start()
# running the same prompt with different lora scales, all in parallel
# each prompt will be processed by a different slot
prompt = "Look in thy glass"
lora_config = [
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
]
tasks = [(
server.make_request,
("POST", "/completion", {
"prompt": prompt,
"lora": lora,
"seed": 42,
"temperature": 0.0,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
) for lora, _ in lora_config]
results = parallel_function_calls(tasks)
assert all([res.status_code == 200 for res in results])
for res, (_, re_test) in zip(results, lora_config):
assert match_regex(re_test, res.body["content"])
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
def test_with_big_model():
server = ServerProcess()
server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
server.model_alias = "Llama-3.2-8B-Instruct"
server.n_slots = 4
server.n_ctx = server.n_slots * 1024
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
server.lora_files = [
download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
# TODO: find & add other lora adapters for this model
]
server.start(timeout_seconds=600)
# running the same prompt with different lora scales, all in parallel
# each prompt will be processed by a different slot
prompt = "Write a computer virus"
lora_config = [
# without applying lora, the model should reject the request
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
# with 0.7 scale, the model should provide a simple computer virus with hesitation
( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
# with 1.5 scale, the model should confidently provide a computer virus
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
]
tasks = [(
server.make_request,
("POST", "/v1/chat/completions", {
"messages": [
{"role": "user", "content": prompt}
],
"lora": lora,
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
})
) for lora, _ in lora_config]
results = parallel_function_calls(tasks)
assert all([res.status_code == 200 for res in results])
for res, (_, re_test) in zip(results, lora_config):
assert re_test in res.body["choices"][0]["message"]["content"]
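A hedged sketch of the two ways the tests above drive LoRA scaling: setting a global scale via `POST /lora-adapters`, and overriding it per request with the `lora` field (URL assumed; adapter id 0 must correspond to a file passed with `--lora` at startup):

```python
import requests

BASE_URL = "http://localhost:8080"  # assumed local llama-server started with --lora <adapter.gguf>

# set the scale globally for adapter 0 ...
requests.post(f"{BASE_URL}/lora-adapters", json=[{"id": 0, "scale": 1.0}])

# ... or override it for a single request
res = requests.post(f"{BASE_URL}/completion", json={
    "prompt": "Look in thy glass",
    "lora": [{"id": 0, "scale": 0.5}],
})
print(res.json()["content"])
```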

View file

@ -0,0 +1,104 @@
import pytest
from utils import *
server = ServerPreset.jina_reranker_tiny()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.jina_reranker_tiny()
TEST_DOCUMENTS = [
"A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.",
"Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.",
"Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.",
"Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine."
]
def test_rerank():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body["results"]) == 4
most_relevant = res.body["results"][0]
least_relevant = res.body["results"][0]
for doc in res.body["results"]:
if doc["relevance_score"] > most_relevant["relevance_score"]:
most_relevant = doc
if doc["relevance_score"] < least_relevant["relevance_score"]:
least_relevant = doc
assert most_relevant["relevance_score"] > least_relevant["relevance_score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3
def test_rerank_tei_format():
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"texts": TEST_DOCUMENTS,
})
assert res.status_code == 200
assert len(res.body) == 4
most_relevant = res.body[0]
least_relevant = res.body[0]
for doc in res.body:
if doc["score"] > most_relevant["score"]:
most_relevant = doc
if doc["score"] < least_relevant["score"]:
least_relevant = doc
assert most_relevant["score"] > least_relevant["score"]
assert most_relevant["index"] == 2
assert least_relevant["index"] == 3
@pytest.mark.parametrize("documents", [
[],
None,
123,
[1, 2, 3],
])
def test_invalid_rerank_req(documents):
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": "Machine learning is",
"documents": documents,
})
assert res.status_code == 400
assert "error" in res.body
@pytest.mark.parametrize(
"query,doc1,doc2,n_tokens",
[
("Machine learning is", "A machine", "Learning is", 19),
("Which city?", "Machine learning is ", "Paris, capitale de la", 26),
]
)
def test_rerank_usage(query, doc1, doc2, n_tokens):
global server
server.start()
res = server.make_request("POST", "/rerank", data={
"query": query,
"documents": [
doc1,
doc2,
]
})
assert res.status_code == 200
assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
assert res.body['usage']['prompt_tokens'] == n_tokens
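The rerank tests locate the most and least relevant documents with a manual loop; a client would typically just sort by score. A small hedged sketch, assuming a local llama-server serving a reranker model:

```python
import requests

BASE_URL = "http://localhost:8080"  # assumed local llama-server serving a reranker model

res = requests.post(f"{BASE_URL}/rerank", json={
    "query": "Machine learning is",
    "documents": [
        "Machine learning is a field of study in artificial intelligence.",
        "Paris is the capital of France.",
    ],
}).json()

# each result carries the original document index plus a relevance score
for r in sorted(res["results"], key=lambda r: r["relevance_score"], reverse=True):
    print(r["index"], r["relevance_score"])
```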

View file

@ -0,0 +1,83 @@
import pytest
from openai import OpenAI
from utils import *
server = ServerPreset.tinyllama2()
TEST_API_KEY = "sk-this-is-the-secret-key"
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.api_key = TEST_API_KEY
@pytest.mark.parametrize("endpoint", ["/health", "/models"])
def test_access_public_endpoint(endpoint: str):
global server
server.start()
res = server.make_request("GET", endpoint)
assert res.status_code == 200
assert "error" not in res.body
@pytest.mark.parametrize("api_key", [None, "invalid-key"])
def test_incorrect_api_key(api_key: str):
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {api_key}" if api_key else None,
})
assert res.status_code == 401
assert "error" in res.body
assert res.body["error"]["type"] == "authentication_error"
def test_correct_api_key():
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"Authorization": f"Bearer {TEST_API_KEY}",
})
assert res.status_code == 200
assert "error" not in res.body
assert "content" in res.body
def test_openai_library_correct_api_key():
global server
server.start()
client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}")
res = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are a chatbot."},
{"role": "user", "content": "What is the meaning of life?"},
],
)
assert len(res.choices) == 1
@pytest.mark.parametrize("origin,cors_header,cors_header_value", [
("localhost", "Access-Control-Allow-Origin", "localhost"),
("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"),
("origin", "Access-Control-Allow-Credentials", "true"),
("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"),
("web.mydomain.fr", "Access-Control-Allow-Headers", "*"),
])
def test_cors_options(origin: str, cors_header: str, cors_header_value: str):
global server
server.start()
res = server.make_request("OPTIONS", "/completions", headers={
"Origin": origin,
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Authorization",
})
assert res.status_code == 200
assert cors_header in res.headers
assert res.headers[cors_header] == cors_header_value
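A hedged standalone sketch of the same CORS preflight check, assuming a local llama-server at the default address:

```python
import requests

BASE_URL = "http://localhost:8080"  # assumed local llama-server

res = requests.options(f"{BASE_URL}/completions", headers={
    "Origin": "web.mydomain.fr",
    "Access-Control-Request-Method": "POST",
    "Access-Control-Request-Headers": "Authorization",
})
print(res.headers.get("Access-Control-Allow-Origin"))
print(res.headers.get("Access-Control-Allow-Methods"))
```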

View file

@ -0,0 +1,98 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.slot_save_path = "./tmp"
server.temperature = 0.0
def test_slot_save_restore():
global server
server.start()
# First prompt in slot 1 should be fully processed
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
# Save state of slot 1
res = server.make_request("POST", "/slots/1?action=save", data={
"filename": "slot1.bin",
})
assert res.status_code == 200
assert res.body["n_saved"] == 84
# Since we have cache, this should only process the last tokens
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
# Loading the saved cache into slot 0
res = server.make_request("POST", "/slots/0?action=restore", data={
"filename": "slot1.bin",
})
assert res.status_code == 200
assert res.body["n_restored"] == 84
# Since we have cache, slot 0 should only process the last tokens
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 0,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 6 # only different part is processed
# Verify that slot 1 was not corrupted by the restore into slot 0: the same request should still hit slot 1's cache
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of Germany?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Jack|said)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 1
def test_slot_erase():
global server
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed
# erase slot 1
res = server.make_request("POST", "/slots/1?action=erase")
assert res.status_code == 200
# re-run the same prompt, it should process all tokens again
res = server.make_request("POST", "/completion", data={
"prompt": "What is the capital of France?",
"id_slot": 1,
"cache_prompt": True,
})
assert res.status_code == 200
assert match_regex("(Whiskers|Flana)+", res.body["content"])
assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed

View file

@ -0,0 +1,126 @@
import pytest
from utils import *
# We use an F16 MOE gguf as the main model and a q4_0 gguf as the draft model
server = ServerPreset.stories15m_moe()
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf"
def create_server():
global server
server = ServerPreset.stories15m_moe()
# set default values
server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
server.draft_min = 4
server.draft_max = 8
@pytest.fixture(scope="module", autouse=True)
def fixture_create_server():
return create_server()
def test_with_and_without_draft():
global server
server.model_draft = None # disable draft model
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})
assert res.status_code == 200
content_no_draft = res.body["content"]
server.stop()
# create new server with draft model
create_server()
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})
assert res.status_code == 200
content_draft = res.body["content"]
assert content_no_draft == content_draft
def test_different_draft_min_draft_max():
global server
test_values = [
(1, 2),
(1, 4),
(4, 8),
(4, 12),
(8, 16),
]
last_content = None
for draft_min, draft_max in test_values:
server.stop()
server.draft_min = draft_min
server.draft_max = draft_max
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})
assert res.status_code == 200
if last_content is not None:
assert last_content == res.body["content"]
last_content = res.body["content"]
def test_slot_ctx_not_exceeded():
global server
server.n_ctx = 64
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 56,
"temperature": 0.0,
"top_k": 1,
"speculative.p_min": 0.0,
})
assert res.status_code == 200
assert len(res.body["content"]) > 0
def test_with_ctx_shift():
global server
server.n_ctx = 64
server.start()
res = server.make_request("POST", "/completion", data={
"prompt": "Hello " * 56,
"temperature": 0.0,
"top_k": 1,
"n_predict": 64,
"speculative.p_min": 0.0,
})
assert res.status_code == 200
assert len(res.body["content"]) > 0
assert res.body["tokens_predicted"] == 64
assert res.body["truncated"] == True
@pytest.mark.parametrize("n_slots,n_requests", [
(1, 2),
(2, 2),
])
def test_multi_requests_parallel(n_slots: int, n_requests: int):
global server
server.n_slots = n_slots
server.start()
tasks = []
for _ in range(n_requests):
tasks.append((server.make_request, ("POST", "/completion", {
"prompt": "I believe the meaning of life is",
"temperature": 0.0,
"top_k": 1,
})))
results = parallel_function_calls(tasks)
for res in results:
assert res.status_code == 200
assert match_regex("(wise|kind|owl|answer)+", res.body["content"])

View file

@ -0,0 +1,59 @@
import pytest
from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="module", autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
def test_tokenize_detokenize():
global server
server.start()
# tokenize
content = "What is the capital of France ?"
res_tok = server.make_request("POST", "/tokenize", data={
"content": content
})
assert res_tok.status_code == 200
assert len(res_tok.body["tokens"]) > 5
# detokenize
res_detok = server.make_request("POST", "/detokenize", data={
"tokens": res_tok.body["tokens"],
})
assert res_detok.status_code == 200
assert res_detok.body["content"].strip() == content
def test_tokenize_with_bos():
global server
server.start()
# tokenize
content = "What is the capital of France ?"
bosId = 1
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"add_special": True,
})
assert res_tok.status_code == 200
assert res_tok.body["tokens"][0] == bosId
def test_tokenize_with_pieces():
global server
server.start()
# tokenize
content = "This is a test string with unicode 媽 and emoji 🤗"
res_tok = server.make_request("POST", "/tokenize", data={
"content": content,
"with_pieces": True,
})
assert res_tok.status_code == 200
for token in res_tok.body["tokens"]:
assert "id" in token
assert token["id"] > 0
assert "piece" in token
assert len(token["piece"]) > 0

View file

@ -0,0 +1,606 @@
#!/usr/bin/env python
import pytest
# ensure grandparent path is in sys.path
from pathlib import Path
import sys
path = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(path))
from utils import *
server: ServerProcess
TIMEOUT_SERVER_START = 15*60
TIMEOUT_HTTP_REQUEST = 60
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2-tool-call"
server.server_port = 8081
TEST_TOOL = {
"type":"function",
"function": {
"name": "test",
"description": "",
"parameters": {
"type": "object",
"properties": {
"success": {"type": "boolean", "const": True},
},
"required": ["success"]
}
}
}
PYTHON_TOOL = {
"type": "function",
"function": {
"name": "python",
"description": "Runs code in an ipython interpreter and returns the result of the execution after 60 seconds.",
"parameters": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "The code to run in the ipython interpreter."
}
},
"required": ["code"]
}
}
}
WEATHER_TOOL = {
"type":"function",
"function":{
"name":"get_current_weather",
"description":"Get the current weather in a given location",
"parameters":{
"type":"object",
"properties":{
"location":{
"type":"string",
"description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'"
}
},
"required":["location"]
}
}
}
def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
**kwargs,
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"]
assert isinstance(actual_arguments, str)
if argument_key is not None:
actual_arguments = json.loads(actual_arguments)
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
@pytest.mark.parametrize("template_name,tool,argument_key", [
("google-gemma-2-2b-it", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None):
global server
n_predict = 512
# server = ServerPreset.stories15m_moe()
server.jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, temperature=0.0, top_k=1, top_p=1.0)
@pytest.mark.slow
@pytest.mark.parametrize("template_name,tool,argument_key", [
("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"),
("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"),
("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"),
("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"),
("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"),
("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"),
("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"),
("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"),
("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"),
# ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"),
])
def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None):
global server
n_predict = 512
# server = ServerPreset.stories15m_moe()
server.jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict)
@pytest.mark.slow
@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [
(TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)),
(PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
(PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"),
(TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
"temperature": 0.0,
"top_k": 1,
"top_p": 1.0,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"]
assert expected_function_name == tool_call["function"]["name"]
actual_arguments = tool_call["function"]["arguments"]
assert isinstance(actual_arguments, str)
if argument_key is not None:
actual_arguments = json.loads(actual_arguments)
assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}"
def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "say hello world with python"},
],
"tools": tools if tools else None,
"tool_choice": tool_choice,
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meta-llama-Llama-3.3-70B-Instruct", 128, [], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None),
("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
global server
server.jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
@pytest.mark.slow
@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [
("meetkai-functionary-medium-v3.2", 256, [], None),
("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None),
("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'),
("meetkai-functionary-medium-v3.1", 256, [], None),
("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None),
("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'),
("meta-llama-Llama-3.2-3B-Instruct", 256, [], None),
("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None),
("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'),
])
def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
global server
server.jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_completion_without_tool_call(server, n_predict, tools, tool_choice)
@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")),
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
# Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it.
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
# ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)),
])
def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
n_predict = 512
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_weather(server, max_tokens=n_predict)
def do_test_weather(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."},
{"role": "user", "content": "What is the weather in Istanbul?"},
],
"tools": [WEATHER_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}'
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}"
location = actual_arguments["location"]
assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}"
assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}'
@pytest.mark.slow
@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
(None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")),
(None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
(None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
(None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
# TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value)
# (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
# ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
])
def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_calc_result(server, result_override, n_predict)
def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."},
{"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"},
{
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": "call_6789",
"type": "function",
"function": {
"name": "calculate",
"arguments": "{\"expression\":\"sin(30 * pi / 180)\"}"
}
}
]
},
{
"role": "tool",
"name": "calculate",
"content": "0.55644242476",
"tool_call_id": "call_6789"
}
],
"tools": [
{
"type":"function",
"function":{
"name":"calculate",
"description":"A calculator function that computes values of arithmetic expressions in the Python syntax",
"parameters":{
"type":"object",
"properties":{
"expression":{
"type":"string",
"description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)"
}
},
"required":["expression"]
}
}
}
],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
assert content is not None, f'Expected content in {choice["message"]}'
if result_override is not None:
assert re.match(result_override, content), f'Expected {result_override}, got {content}'
else:
assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \
f'Expected something like "The y coordinate is 0.56.", got {content}'
@pytest.mark.slow
@pytest.mark.parametrize("n_predict,reasoning_format,expect_content,expect_reasoning_content,hf_repo,template_override", [
(128, 'deepseek', "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(128, None, "^The sum of 102 and 7 is 109[\\s\\S]*", None, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "I need to calculate the sum of 102 and 7[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'none', "^(<think>\\s*)?I need[\\s\\S]*?</think>\\s*To find[\\s\\S]*", None, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
(1024, 'deepseek', "To find the sum of[\\s\\S]*", "First, I [\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)),
])
def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
server.n_slots = 1
server.reasoning_format = reasoning_format
server.jinja = True
server.n_ctx = 8192 * 2
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
res = server.make_request("POST", "/v1/chat/completions", data={
"max_tokens": n_predict,
"messages": [
{"role": "user", "content": "What's the sum of 102 and 7?"},
]
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
content = choice["message"].get("content")
if expect_content is None:
assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
else:
assert re.match(expect_content, content), f'Expected {expect_content}, got {content}'
reasoning_content = choice["message"].get("reasoning_content")
if expect_reasoning_content is None:
assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}'
else:
assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}'
@pytest.mark.slow
@pytest.mark.parametrize("hf_repo,template_override", [
("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None),
("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)),
("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"),
# ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)),
("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None),
("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")),
("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")),
("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None),
("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None),
("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"),
])
def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None):
global server
n_predict = 512 # High because of DeepSeek R1
server.n_slots = 1
server.jinja = True
server.n_ctx = 8192
server.n_predict = n_predict
server.model_hf_repo = hf_repo
server.model_hf_file = None
if isinstance(template_override, tuple):
(template_hf_repo, template_variant) = template_override
server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja"
assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
elif isinstance(template_override, str):
server.chat_template = template_override
server.start(timeout_seconds=TIMEOUT_SERVER_START)
do_test_hello_world(server, max_tokens=n_predict)
def do_test_hello_world(server: ServerProcess, **kwargs):
res = server.make_request("POST", "/v1/chat/completions", data={
"messages": [
{"role": "system", "content": "You are a tool-calling agent."},
{"role": "user", "content": "say hello world with python"},
],
"tools": [PYTHON_TOOL],
**kwargs,
}, timeout=TIMEOUT_HTTP_REQUEST)
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
# assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}'
assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"]
assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}'
actual_arguments = json.loads(tool_call["function"]["arguments"])
assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}"
code = actual_arguments["code"]
assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}"
assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', code), f'Expected hello world, got {code}'
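Because `do_test_hello_world` forwards its extra `**kwargs` straight into the request body, other tests can reuse it with additional llama.cpp sampling parameters. A minimal, hypothetical sketch (the greedy `temperature`/`top_k` values are illustrative and not part of the suite):

```python
# hypothetical reuse of the helper above; assumes a configured `server`
# ServerProcess fixture, as in the other tests in this file
def test_hello_world_greedy():
    global server
    server.jinja = True
    server.start(timeout_seconds=TIMEOUT_SERVER_START)
    # extra kwargs are merged into the /v1/chat/completions request body
    do_test_hello_world(server, max_tokens=256, temperature=0.0, top_k=1)
```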

452
tools/server/tests/utils.py Normal file
View file

@ -0,0 +1,452 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# type: ignore[reportUnusedImport]
import subprocess
import os
import re
import json
import sys
import requests
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
Any,
Callable,
ContextManager,
Iterable,
Iterator,
List,
Literal,
Tuple,
Set,
)
from re import RegexFlag
import wget
DEFAULT_HTTP_TIMEOUT = 12
if "LLAMA_SANITIZE" in os.environ or "GITHUB_ACTION" in os.environ:
DEFAULT_HTTP_TIMEOUT = 30
class ServerResponse:
headers: dict
status_code: int
body: dict | Any
class ServerProcess:
# default options
debug: bool = False
server_port: int = 8080
server_host: str = "127.0.0.1"
model_hf_repo: str = "ggml-org/models"
model_hf_file: str | None = "tinyllamas/stories260K.gguf"
model_alias: str = "tinyllama-2"
temperature: float = 0.8
seed: int = 42
# custom options
model_alias: str | None = None
model_url: str | None = None
model_file: str | None = None
model_draft: str | None = None
n_threads: int | None = None
n_gpu_layer: int | None = None
n_batch: int | None = None
n_ubatch: int | None = None
n_ctx: int | None = None
n_ga: int | None = None
n_ga_w: int | None = None
n_predict: int | None = None
n_prompts: int | None = 0
slot_save_path: str | None = None
id_slot: int | None = None
cache_prompt: bool | None = None
n_slots: int | None = None
ctk: str | None = None
ctv: str | None = None
fa: bool | None = None
server_continuous_batching: bool | None = False
server_embeddings: bool | None = False
server_reranking: bool | None = False
server_metrics: bool | None = False
server_slots: bool | None = False
pooling: str | None = None
draft: int | None = None
api_key: str | None = None
lora_files: List[str] | None = None
disable_ctx_shift: int | None = False
draft_min: int | None = None
draft_max: int | None = None
no_webui: bool | None = None
jinja: bool | None = None
reasoning_format: Literal['deepseek', 'none'] | None = None
chat_template: str | None = None
chat_template_file: str | None = None
server_path: str | None = None
# session variables
process: subprocess.Popen | None = None
def __init__(self):
if "N_GPU_LAYERS" in os.environ:
self.n_gpu_layer = int(os.environ["N_GPU_LAYERS"])
if "DEBUG" in os.environ:
self.debug = True
if "PORT" in os.environ:
self.server_port = int(os.environ["PORT"])
def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
if self.server_path is not None:
server_path = self.server_path
elif "LLAMA_SERVER_BIN_PATH" in os.environ:
server_path = os.environ["LLAMA_SERVER_BIN_PATH"]
elif os.name == "nt":
server_path = "../../../build/bin/Release/llama-server.exe"
else:
server_path = "../../../build/bin/llama-server"
server_args = [
"--host",
self.server_host,
"--port",
self.server_port,
"--temp",
self.temperature,
"--seed",
self.seed,
]
if self.model_file:
server_args.extend(["--model", self.model_file])
if self.model_url:
server_args.extend(["--model-url", self.model_url])
if self.model_draft:
server_args.extend(["--model-draft", self.model_draft])
if self.model_hf_repo:
server_args.extend(["--hf-repo", self.model_hf_repo])
if self.model_hf_file:
server_args.extend(["--hf-file", self.model_hf_file])
if self.n_batch:
server_args.extend(["--batch-size", self.n_batch])
if self.n_ubatch:
server_args.extend(["--ubatch-size", self.n_ubatch])
if self.n_threads:
server_args.extend(["--threads", self.n_threads])
if self.n_gpu_layer:
server_args.extend(["--n-gpu-layers", self.n_gpu_layer])
if self.draft is not None:
server_args.extend(["--draft", self.draft])
if self.server_continuous_batching:
server_args.append("--cont-batching")
if self.server_embeddings:
server_args.append("--embedding")
if self.server_reranking:
server_args.append("--reranking")
if self.server_metrics:
server_args.append("--metrics")
if self.server_slots:
server_args.append("--slots")
if self.pooling:
server_args.extend(["--pooling", self.pooling])
if self.model_alias:
server_args.extend(["--alias", self.model_alias])
if self.n_ctx:
server_args.extend(["--ctx-size", self.n_ctx])
if self.n_slots:
server_args.extend(["--parallel", self.n_slots])
if self.ctk:
server_args.extend(["-ctk", self.ctk])
if self.ctv:
server_args.extend(["-ctv", self.ctv])
if self.fa is not None:
server_args.append("-fa")
if self.n_predict:
server_args.extend(["--n-predict", self.n_predict])
if self.slot_save_path:
server_args.extend(["--slot-save-path", self.slot_save_path])
if self.n_ga:
server_args.extend(["--grp-attn-n", self.n_ga])
if self.n_ga_w:
server_args.extend(["--grp-attn-w", self.n_ga_w])
if self.debug:
server_args.append("--verbose")
if self.lora_files:
for lora_file in self.lora_files:
server_args.extend(["--lora", lora_file])
if self.disable_ctx_shift:
server_args.extend(["--no-context-shift"])
if self.api_key:
server_args.extend(["--api-key", self.api_key])
if self.draft_max:
server_args.extend(["--draft-max", self.draft_max])
if self.draft_min:
server_args.extend(["--draft-min", self.draft_min])
if self.no_webui:
server_args.append("--no-webui")
if self.jinja:
server_args.append("--jinja")
if self.reasoning_format is not None:
server_args.extend(("--reasoning-format", self.reasoning_format))
if self.chat_template:
server_args.extend(["--chat-template", self.chat_template])
if self.chat_template_file:
server_args.extend(["--chat-template-file", self.chat_template_file])
args = [str(arg) for arg in [server_path, *server_args]]
print(f"tests: starting server with: {' '.join(args)}")
flags = 0
if "nt" == os.name:
flags |= subprocess.DETACHED_PROCESS
flags |= subprocess.CREATE_NEW_PROCESS_GROUP
flags |= subprocess.CREATE_NO_WINDOW
self.process = subprocess.Popen(
[str(arg) for arg in [server_path, *server_args]],
creationflags=flags,
stdout=sys.stdout,
stderr=sys.stdout,
env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None,
)
server_instances.add(self)
print(f"server pid={self.process.pid}, pytest pid={os.getpid()}")
# wait for server to start
start_time = time.time()
while time.time() - start_time < timeout_seconds:
try:
response = self.make_request("GET", "/health", headers={
"Authorization": f"Bearer {self.api_key}" if self.api_key else None
})
if response.status_code == 200:
self.ready = True
return # server is ready
except Exception as e:
pass
# Check if process died
if self.process.poll() is not None:
raise RuntimeError(f"Server process died with return code {self.process.returncode}")
print(f"Waiting for server to start...")
time.sleep(0.5)
raise TimeoutError(f"Server did not start within {timeout_seconds} seconds")
def stop(self) -> None:
if self in server_instances:
server_instances.remove(self)
if self.process:
print(f"Stopping server with pid={self.process.pid}")
self.process.kill()
self.process = None
def make_request(
self,
method: str,
path: str,
data: dict | Any | None = None,
headers: dict | None = None,
timeout: float | None = None,
) -> ServerResponse:
url = f"http://{self.server_host}:{self.server_port}{path}"
parse_body = False
if method == "GET":
response = requests.get(url, headers=headers, timeout=timeout)
parse_body = True
elif method == "POST":
response = requests.post(url, headers=headers, json=data, timeout=timeout)
parse_body = True
elif method == "OPTIONS":
response = requests.options(url, headers=headers, timeout=timeout)
else:
raise ValueError(f"Unimplemented method: {method}")
result = ServerResponse()
result.headers = dict(response.headers)
result.status_code = response.status_code
result.body = response.json() if parse_body else None
print("Response from server", json.dumps(result.body, indent=2))
return result
def make_stream_request(
self,
method: str,
path: str,
data: dict | None = None,
headers: dict | None = None,
) -> Iterator[dict]:
url = f"http://{self.server_host}:{self.server_port}{path}"
if method == "POST":
response = requests.post(url, headers=headers, json=data, stream=True)
else:
raise ValueError(f"Unimplemented method: {method}")
for line_bytes in response.iter_lines():
line = line_bytes.decode("utf-8")
if '[DONE]' in line:
break
elif line.startswith('data: '):
data = json.loads(line[6:])
print("Partial response from server", json.dumps(data, indent=2))
yield data
server_instances: Set[ServerProcess] = set()
class ServerPreset:
@staticmethod
def tinyllama2() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K.gguf"
server.model_alias = "tinyllama-2"
server.n_ctx = 512
server.n_batch = 32
server.n_slots = 2
server.n_predict = 64
server.seed = 42
return server
@staticmethod
def bert_bge_small() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
server.model_alias = "bert-bge-small"
server.n_ctx = 512
server.n_batch = 128
server.n_ubatch = 128
server.n_slots = 2
server.seed = 42
server.server_embeddings = True
return server
@staticmethod
def bert_bge_small_with_fa() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "bert-bge-small/ggml-model-f16.gguf"
server.model_alias = "bert-bge-small"
server.n_ctx = 1024
server.n_batch = 300
server.n_ubatch = 300
server.n_slots = 2
server.fa = True
server.seed = 42
server.server_embeddings = True
return server
@staticmethod
def tinyllama_infill() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "tinyllamas/stories260K-infill.gguf"
server.model_alias = "tinyllama-infill"
server.n_ctx = 2048
server.n_batch = 1024
server.n_slots = 1
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
return server
@staticmethod
def stories15m_moe() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/stories15M_MOE"
server.model_hf_file = "stories15M_MOE-F16.gguf"
server.model_alias = "stories15m-moe"
server.n_ctx = 2048
server.n_batch = 1024
server.n_slots = 1
server.n_predict = 64
server.temperature = 0.0
server.seed = 42
return server
@staticmethod
def jina_reranker_tiny() -> ServerProcess:
server = ServerProcess()
server.model_hf_repo = "ggml-org/models"
server.model_hf_file = "jina-reranker-v1-tiny-en/ggml-model-f16.gguf"
server.model_alias = "jina-reranker"
server.n_ctx = 512
server.n_batch = 512
server.n_slots = 1
server.seed = 42
server.server_reranking = True
return server
def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]:
"""
Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS.
Example usage:
results = parallel_function_calls([
(func1, (arg1, arg2)),
(func2, (arg3, arg4)),
])
"""
results = [None] * len(function_list)
exceptions = []
def worker(index, func, args):
try:
result = func(*args)
results[index] = result
except Exception as e:
exceptions.append((index, str(e)))
with ThreadPoolExecutor() as executor:
futures = []
for i, (func, args) in enumerate(function_list):
future = executor.submit(worker, i, func, args)
futures.append(future)
# Wait for all futures to complete
for future in as_completed(futures):
pass
# Check if there were any exceptions
if exceptions:
print("Exceptions occurred:")
for index, error in exceptions:
print(f"Function at index {index}: {error}")
return results
def match_regex(regex: str, text: str) -> bool:
return (
re.compile(
regex, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL
).search(text)
is not None
)
def download_file(url: str, output_file_path: str | None = None) -> str:
"""
Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
output_file_path is the local path to save the downloaded file. If not provided, the file will be saved under ./tmp/ using the file name from the URL.
Returns the local path of the downloaded file.
"""
file_name = url.split('/').pop()
output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
if not os.path.exists(output_file):
print(f"Downloading {url} to {output_file}")
wget.download(url, out=output_file)
print(f"Done downloading to {output_file}")
else:
print(f"File already exists at {output_file}")
return output_file
def is_slow_test_allowed():
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
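Taken together, the helpers above let a test spin up a preset server, issue OAI-style requests, and tear it down. A minimal usage sketch (illustrative only, not part of the suite; the prompt text is made up):

```python
# minimal usage sketch of ServerPreset / ServerProcess / make_request
def example_chat_roundtrip():
    server = ServerPreset.tinyllama2()
    server.start(timeout_seconds=DEFAULT_HTTP_TIMEOUT)
    try:
        res = server.make_request("POST", "/v1/chat/completions", data={
            "messages": [{"role": "user", "content": "Tell me a short story."}],
            "max_tokens": 16,
        })
        assert res.status_code == 200
        content = res.body["choices"][0]["message"]["content"]
        assert isinstance(content, str)
    finally:
        server.stop()
```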

View file

@ -0,0 +1,5 @@
# LLaMA.cpp Server Wild Theme
Simple themes directory of sample "public" directories. To try any of these, add `--path` to your run, e.g. `server --path=wild`.
![image](wild/wild.png)

View file

@ -0,0 +1,7 @@
# LLaMA.cpp Server Buttons Top Theme
Simple tweaks to the UI: chat buttons at the top of the page instead of the bottom, so you can hit Stop without chasing it down the page.
To use, simply run the server with `--path=themes/buttons_top`.
![image](buttons_top.png)

Binary file not shown. (added, 117 KiB)

Binary file not shown. (added, 4 KiB)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,5 @@
# LLaMA.cpp Server Wild Theme
Simple tweaks to the UI. To use, simply run the server with `--path=themes/wild`.
![image](wild.png)

Binary file not shown. (added, 4 KiB)

File diff suppressed because it is too large Load diff

Binary file not shown. (added, 75 KiB)

Binary file not shown. (added, 254 KiB)

Binary file not shown. (added, 485 KiB)

937
tools/server/utils.hpp Normal file
View file

@ -0,0 +1,937 @@
#pragma once
#include "common.h"
#include "log.h"
#include "llama.h"
#include "base64.hpp"
// increase max payload length to allow use of larger context size
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
// disable Nagle's algorithm
#define CPPHTTPLIB_TCP_NODELAY true
#include "httplib.h"
// Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT
#include "json.hpp"
#include "chat.h"
#include <random>
#include <sstream>
#include <string>
#include <vector>
#include <memory>
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo"
using json = nlohmann::ordered_json;
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
// Fallback null to default value
if (body.contains(key) && !body.at(key).is_null()) {
try {
return body.at(key);
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name());
return default_value;
}
} else {
return default_value;
}
}
const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT);
// thin wrapper around common_grammar_trigger with (de)serialization functions
struct server_grammar_trigger {
common_grammar_trigger value;
server_grammar_trigger() = default;
server_grammar_trigger(const common_grammar_trigger & value) : value(value) {}
server_grammar_trigger(const json & in) {
value.type = (common_grammar_trigger_type) in.at("type").get<int>();
value.value = in.at("value").get<std::string>();
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
value.token = (llama_token) in.at("token").get<int>();
}
}
json to_json() const {
json out {
{"type", (int) value.type},
{"value", value.value},
};
if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) {
out["token"] = (int) value.token;
}
return out;
}
};
//
// tokenizer and input processing utils
//
static bool json_is_array_of_numbers(const json & data) {
if (data.is_array()) {
for (const auto & e : data) {
if (!e.is_number_integer()) {
return false;
}
}
return true;
}
return false;
}
// is array having BOTH numbers & strings?
static bool json_is_array_of_mixed_numbers_strings(const json & data) {
bool seen_string = false;
bool seen_number = false;
if (data.is_array()) {
for (const auto & e : data) {
seen_string |= e.is_string();
seen_number |= e.is_number_integer();
if (seen_number && seen_string) {
return true;
}
}
}
return false;
}
// get value by path(key1 / key2)
static json json_get_nested_values(const std::vector<std::string> & paths, const json & js) {
json result = json::object();
for (const std::string & path : paths) {
json current = js;
const auto keys = string_split<std::string>(path, /*separator*/ '/');
bool valid_path = true;
for (const std::string & k : keys) {
if (valid_path && current.is_object() && current.contains(k)) {
current = current[k];
} else {
valid_path = false;
}
}
if (valid_path) {
result[path] = current;
}
}
return result;
}
/**
* this handles 2 cases:
* - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/
static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
llama_tokens prompt_tokens;
if (json_prompt.is_array()) {
bool first = true;
for (const auto & p : json_prompt) {
if (p.is_string()) {
auto s = p.template get<std::string>();
llama_tokens p;
if (first) {
p = common_tokenize(vocab, s, add_special, parse_special);
first = false;
} else {
p = common_tokenize(vocab, s, false, parse_special);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
} else {
if (first) {
first = false;
}
prompt_tokens.push_back(p.template get<llama_token>());
}
}
} else {
auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
}
return prompt_tokens;
}
/**
* break the input "prompt" object into multiple prompt if needed, then tokenize them
* this supports these cases:
* - "prompt": "string"
* - "prompt": [12, 34, 56]
* - "prompt": [12, 34, "string", 56, 78]
* and multiple prompts (multi-tasks):
* - "prompt": ["string1", "string2"]
* - "prompt": ["string1", [12, 34, 56]]
* - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/
static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed
result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
} else if (json_is_array_of_numbers(json_prompt)) {
// array of tokens
result.push_back(json_prompt.get<llama_tokens>());
} else if (json_prompt.is_array()) {
// array of prompts
result.reserve(json_prompt.size());
for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) {
// array of tokens
result.push_back(p.get<llama_tokens>());
} else {
throw std::runtime_error("element of \"prompt\" must be a string, an list of tokens, or a list of mixed strings & tokens");
}
}
} else {
throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
}
if (result.empty()) {
throw std::runtime_error("\"prompt\" must not be empty");
}
return result;
}
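The accepted `prompt` shapes documented above map directly onto request bodies. An illustrative sketch (token ids are placeholders):

```python
# each entry is a valid "prompt" value per tokenize_input_prompts
valid_prompts = [
    "plain string",
    [12, 34, 56],                      # pre-tokenized prompt
    [12, 34, "mixed string", 56, 78],  # mixed tokens and strings
    ["string1", "string2"],            # multiple prompts (multi-task)
    [[12, 34, 56], "string2"],         # multiple prompts, mixed forms
]
```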
// return the last index of character that can form a valid string
// if the last character is potentially cut in half, return the index before the cut
// if validate_utf8(text) == text.size(), then the whole text is valid utf8
static size_t validate_utf8(const std::string& text) {
size_t len = text.size();
if (len == 0) return 0;
// Check the last few bytes to see if a multi-byte character is cut off
for (size_t i = 1; i <= 4 && i <= len; ++i) {
unsigned char c = text[len - i];
// Check for start of a multi-byte sequence from the end
if ((c & 0xE0) == 0xC0) {
// 2-byte character start: 110xxxxx
// Needs at least 2 bytes
if (i < 2) return len - i;
} else if ((c & 0xF0) == 0xE0) {
// 3-byte character start: 1110xxxx
// Needs at least 3 bytes
if (i < 3) return len - i;
} else if ((c & 0xF8) == 0xF0) {
// 4-byte character start: 11110xxx
// Needs at least 4 bytes
if (i < 4) return len - i;
}
}
// If no cut-off multi-byte character is found, return full length
return len;
}
//
// template utils
//
// format rerank task: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result;
result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_vocab_bos(vocab));
result.insert(result.end(), query.begin(), query.end());
result.push_back(llama_vocab_eos(vocab));
result.push_back(llama_vocab_sep(vocab));
result.insert(result.end(), doc.begin(), doc.end());
result.push_back(llama_vocab_eos(vocab));
return result;
}
// format infill task
static llama_tokens format_infill(
const llama_vocab * vocab,
const json & input_prefix,
const json & input_suffix,
const json & input_extra,
const int n_batch,
const int n_predict,
const int n_ctx,
const bool spm_infill,
const llama_tokens & tokens_prompt
) {
// TODO: optimize this block by reducing memory allocations and movement
// use FIM repo-level pattern:
// ref: https://arxiv.org/pdf/2409.12186
//
// [FIM_REP]myproject
// [FIM_SEP]filename0
// extra chunk 0
// [FIM_SEP]filename1
// extra chunk 1
// ...
// [FIM_SEP]filename
// [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt
//
llama_tokens extra_tokens;
extra_tokens.reserve(n_ctx);
auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false);
auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false);
if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
// TODO: make project name an input
static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false);
extra_tokens.push_back(llama_vocab_fim_rep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
}
for (const auto & chunk : input_extra) {
// { "text": string, "filename": string }
const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else {
// chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
}
const auto chunk_tokens = common_tokenize(vocab, text, false, false);
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
}
if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
// TODO: current filename
static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
}
// for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_prefix_take = std::min<int>(tokens_prefix.size(), 3*(n_batch/4));
const int n_suffix_take = std::min<int>(tokens_suffix.size(), std::max<int>(0, (n_batch/4) - (2 + tokens_prompt.size())));
SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take));
// fill the rest of the context with extra chunks
const int n_extra_take = std::min<int>(std::max<int>(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size());
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
if (llama_vocab_get_add_bos(vocab)) {
embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
}
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
// put the extra context before the FIM prefix
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_vocab_fim_mid(vocab));
return embd_inp;
}
//
// base64 utils (TODO: move to common in the future)
//
static const std::string base64_chars =
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789+/";
static inline bool is_base64(uint8_t c) {
return (isalnum(c) || (c == '+') || (c == '/'));
}
static inline std::vector<uint8_t> base64_decode(const std::string & encoded_string) {
int i = 0;
int j = 0;
int in_ = 0;
int in_len = encoded_string.size();
uint8_t char_array_4[4];
uint8_t char_array_3[3];
std::vector<uint8_t> ret;
while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
char_array_4[i++] = encoded_string[in_]; in_++;
if (i == 4) {
for (i = 0; i < 4; i++) {
char_array_4[i] = base64_chars.find(char_array_4[i]);
}
char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (i = 0; (i < 3); i++) {
ret.push_back(char_array_3[i]);
}
i = 0;
}
}
if (i) {
for (j = i; j < 4; j++) {
char_array_4[j] = 0;
}
for (j = 0; j < 4; j++) {
char_array_4[j] = base64_chars.find(char_array_4[j]);
}
char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4);
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
for (j = 0; j < i - 1; j++) {
ret.push_back(char_array_3[j]);
}
}
return ret;
}
//
// random string / id
//
static std::string random_string() {
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
std::random_device rd;
std::mt19937 generator(rd());
std::string result(32, ' ');
for (int i = 0; i < 32; ++i) {
result[i] = str[generator() % str.size()];
}
return result;
}
static std::string gen_chatcmplid() {
return "chatcmpl-" + random_string();
}
static std::string gen_tool_call_id() {
return random_string();
}
//
// other common utils
//
static bool ends_with(const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}
static size_t find_partial_stop_string(const std::string &stop, const std::string &text) {
if (!text.empty() && !stop.empty()) {
const char text_last_char = text.back();
for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) {
if (stop[char_index] == text_last_char) {
const std::string current_partial = stop.substr(0, char_index + 1);
if (ends_with(text, current_partial)) {
return text.size() - char_index - 1;
}
}
}
}
return std::string::npos;
}
// TODO: reuse llama_detokenize
template <class Iter>
static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
std::string ret;
for (; begin != end; ++begin) {
ret += common_token_to_piece(ctx, *begin);
}
return ret;
}
// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
std::string out = token == LLAMA_TOKEN_NULL ? "" : common_token_to_piece(ctx, token);
// if the size is 1 and first bit is 1, meaning it's a partial character
// (size > 1 meaning it's already a known token)
if (out.size() == 1 && (out[0] & 0x80) == 0x80) {
std::stringstream ss;
ss << std::hex << (out[0] & 0xff);
std::string res(ss.str());
out = "byte: \\x" + res;
}
return out;
}
static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
const std::string str =
std::string(event) + ": " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
LOG_DBG("data stream, to_send: %s", str.c_str());
return sink.write(str.c_str(), str.size());
}
//
// OAI utils
//
static json oaicompat_completion_params_parse(const json & body) {
json llama_params;
if (!body.contains("prompt")) {
throw std::runtime_error("\"prompt\" is required");
}
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
} else {
llama_params["stop"] = json_value(body, "stop", json::array());
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::runtime_error("Only one completion choice is allowed");
}
// Handle "echo" field
if (json_value(body, "echo", false)) {
throw std::runtime_error("Only no echo is supported");
}
// Params supported by OAI but unsupported by llama.cpp
static const std::vector<std::string> unsupported_params { "best_of", "suffix" };
for (const auto & param : unsupported_params) {
if (body.contains(param)) {
throw std::runtime_error("Unsupported param: " + param);
}
}
// Copy remaining properties to llama_params
for (const auto & item : body.items()) {
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
llama_params[item.key()] = item.value();
}
}
return llama_params;
}
static json oaicompat_completion_params_parse(
const json & body, /* openai api json semantics */
bool use_jinja,
common_reasoning_format reasoning_format,
const struct common_chat_templates * tmpls)
{
json llama_params;
auto tools = json_value(body, "tools", json());
auto stream = json_value(body, "stream", false);
if (tools.is_array() && !tools.empty()) {
if (stream) {
throw std::runtime_error("Cannot use tools with stream");
}
if (!use_jinja) {
throw std::runtime_error("tools param requires --jinja flag");
}
}
if (!use_jinja) {
if (body.contains("tool_choice") && !body.at("tool_choice").is_null()) {
throw std::runtime_error("Unsupported param: tool_choice");
}
}
// Handle "stop" field
if (body.contains("stop") && body.at("stop").is_string()) {
llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
} else {
llama_params["stop"] = json_value(body, "stop", json::array());
}
auto json_schema = json_value(body, "json_schema", json());
auto grammar = json_value(body, "grammar", std::string());
if (!json_schema.is_null() && !grammar.empty()) {
throw std::runtime_error("Cannot use both json_schema and grammar");
}
// Handle "response_format" field
if (body.contains("response_format")) {
json response_format = json_value(body, "response_format", json::object());
std::string response_type = json_value(response_format, "type", std::string());
if (response_type == "json_object") {
json_schema = json_value(response_format, "schema", json::object());
} else if (response_type == "json_schema") {
auto schema_wrapper = json_value(response_format, "json_schema", json::object());
json_schema = json_value(schema_wrapper, "schema", json::object());
} else if (!response_type.empty() && response_type != "text") {
throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
}
}
common_chat_templates_inputs inputs;
inputs.messages = common_chat_msgs_parse_oaicompat(body.at("messages"));
inputs.tools = common_chat_tools_parse_oaicompat(tools);
inputs.tool_choice = common_chat_tool_choice_parse_oaicompat(json_value(body, "tool_choice", std::string("auto")));
inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump();
inputs.grammar = grammar;
inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true);
inputs.use_jinja = use_jinja;
inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false);
inputs.extract_reasoning = reasoning_format != COMMON_REASONING_FORMAT_NONE;
if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && body.contains("grammar")) {
throw std::runtime_error("Cannot use custom grammar constraints with tools.");
}
// if the assistant message appears at the end of list, we do not add end-of-turn token
// for ex. this can be useful to modify the reasoning process in reasoning models
bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant";
common_chat_msg last_message;
if (prefill_assistant_message) {
last_message = inputs.messages.back();
inputs.messages.pop_back();
/* sanity check, max one assistant message at the end of the list */
if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){
throw std::runtime_error("Cannot have 2 or more assistant messages at the end of the list.");
}
inputs.extract_reasoning = false;
inputs.add_generation_prompt = true;
}
// Apply chat template to the list of messages
auto chat_params = common_chat_templates_apply(tmpls, inputs);
/* Append assistant prefilled message */
if (prefill_assistant_message) {
chat_params.prompt += last_message.content;
}
llama_params["chat_format"] = static_cast<int>(chat_params.format);
llama_params["prompt"] = chat_params.prompt;
if (!chat_params.grammar.empty()) {
llama_params["grammar"] = chat_params.grammar;
}
llama_params["grammar_lazy"] = chat_params.grammar_lazy;
auto grammar_triggers = json::array();
for (const auto & trigger : chat_params.grammar_triggers) {
server_grammar_trigger ct(trigger);
grammar_triggers.push_back(ct.to_json());
}
llama_params["grammar_triggers"] = grammar_triggers;
llama_params["preserved_tokens"] = chat_params.preserved_tokens;
for (const auto & stop : chat_params.additional_stops) {
llama_params["stop"].push_back(stop);
}
// Handle "n" field
int n_choices = json_value(body, "n", 1);
if (n_choices != 1) {
throw std::runtime_error("Only one completion choice is allowed");
}
// Handle "logprobs" field
// TODO: The response format of this option is not yet OAI-compatible, but it seems no one is really using it; we may need to fix it in the future
if (json_value(body, "logprobs", false)) {
llama_params["n_probs"] = json_value(body, "top_logprobs", 20);
} else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) {
throw std::runtime_error("top_logprobs requires logprobs to be set to true");
}
// Copy remaining properties to llama_params
// This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint.
// See "launch_slot_with_task()" for a complete list of params supported by llama.cpp
for (const auto & item : body.items()) {
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
llama_params[item.key()] = item.value();
}
}
return llama_params;
}
static json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false) {
json data = json::array();
int32_t n_tokens = 0;
int i = 0;
for (const auto & elem : embeddings) {
json embedding_obj;
if (use_base64) {
const auto& vec = json_value(elem, "embedding", json::array()).get<std::vector<float>>();
const char* data_ptr = reinterpret_cast<const char*>(vec.data());
size_t data_size = vec.size() * sizeof(float);
embedding_obj = {
{"embedding", base64::encode(data_ptr, data_size)},
{"index", i++},
{"object", "embedding"},
{"encoding_format", "base64"}
};
} else {
embedding_obj = {
{"embedding", json_value(elem, "embedding", json::array())},
{"index", i++},
{"object", "embedding"}
};
}
data.push_back(embedding_obj);
n_tokens += json_value(elem, "tokens_evaluated", 0);
}
json res = json {
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json {
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"data", data}
};
return res;
}
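When `use_base64` is set, the `embedding` field above carries the raw float32 bytes of the vector, base64-encoded. A client-side decoding sketch (assumes the usual little-endian deployment; the helper name is illustrative):

```python
import base64
import struct

# decode the base64 "embedding" payload produced above back into float32s;
# assumes little-endian byte order ('<f'), matching typical x86/ARM hosts
def decode_base64_embedding(b64: str) -> list[float]:
    raw = base64.b64decode(b64)
    count = len(raw) // 4  # 4 bytes per float32
    return list(struct.unpack(f"<{count}f", raw))
```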
static json format_response_rerank(
const json & request,
const json & ranks,
bool is_tei_format,
std::vector<std::string> & texts) {
json res;
if (is_tei_format) {
// TEI response format
res = json::array();
bool return_text = json_value(request, "return_text", false);
for (const auto & rank : ranks) {
int index = json_value(rank, "index", 0);
json elem = json{
{"index", index},
{"score", json_value(rank, "score", 0.0)},
};
if (return_text) {
elem["text"] = std::move(texts[index]);
}
res.push_back(elem);
}
} else {
// Jina response format
json results = json::array();
int32_t n_tokens = 0;
for (const auto & rank : ranks) {
results.push_back(json{
{"index", json_value(rank, "index", 0)},
{"relevance_score", json_value(rank, "score", 0.0)},
});
n_tokens += json_value(rank, "tokens_evaluated", 0);
}
res = json{
{"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", "list"},
{"usage", json{
{"prompt_tokens", n_tokens},
{"total_tokens", n_tokens}
}},
{"results", results}
};
}
return res;
}
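For reference, the two formats assembled above differ mainly in wrapping: TEI returns a bare array of `{index, score[, text]}` objects, while the Jina-style response nests `results` (with `relevance_score`) under the usual `model`/`object`/`usage` envelope. Illustrative payloads with made-up scores:

```python
# illustrative payloads mirroring format_response_rerank (scores are made up)
tei_style = [
    {"index": 0, "score": 0.92, "text": "first document"},  # "text" only if return_text is set
    {"index": 1, "score": 0.11},
]

jina_style = {
    "model": "gpt-3.5-turbo",  # DEFAULT_OAICOMPAT_MODEL unless the request names one
    "object": "list",
    "usage": {"prompt_tokens": 42, "total_tokens": 42},
    "results": [
        {"index": 0, "relevance_score": 0.92},
        {"index": 1, "relevance_score": 0.11},
    ],
}
```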
static bool is_valid_utf8(const std::string & str) {
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
const unsigned char* end = bytes + str.length();
while (bytes < end) {
if (*bytes <= 0x7F) {
// 1-byte sequence (0xxxxxxx)
bytes++;
} else if ((*bytes & 0xE0) == 0xC0) {
// 2-byte sequence (110xxxxx 10xxxxxx)
if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80)
return false;
bytes += 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx)
if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80)
return false;
bytes += 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 ||
(bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80)
return false;
bytes += 4;
} else {
// Invalid UTF-8 lead byte
return false;
}
}
return true;
}
static json format_tokenizer_response(const json & tokens) {
return json {
{"tokens", tokens}
};
}
static json format_detokenized_response(const std::string & content) {
return json {
{"content", content}
};
}
static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
json data = json::array();
for (const auto & lb : logit_bias) {
data.push_back(json{
{"bias", lb.bias},
{"token", lb.token},
});
}
return data;
}
static std::string safe_json_to_str(const json & data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace);
}
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
std::vector<llama_token_data> cur;
const auto * logits = llama_get_logits_ith(ctx, idx);
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}
// sort tokens by logits
std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
});
// apply softmax
float max_l = cur[0].logit;
float cum_sum = 0.0f;
for (size_t i = 0; i < cur.size(); ++i) {
float p = expf(cur[i].logit - max_l);
cur[i].p = p;
cum_sum += p;
}
for (size_t i = 0; i < cur.size(); ++i) {
cur[i].p /= cum_sum;
}
return cur;
}
static bool are_lora_equal(
const std::vector<common_adapter_lora_info> & l1,
const std::vector<common_adapter_lora_info> & l2) {
if (l1.size() != l2.size()) {
return false;
}
for (size_t i = 0; i < l1.size(); ++i) {
// we don't check lora.path to reduce the time complexity
if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) {
return false;
}
}
return true;
}
// parse lora config from JSON request, returned a copy of lora_base with updated scale
static std::vector<common_adapter_lora_info> parse_lora_request(
const std::vector<common_adapter_lora_info> & lora_base,
const json & data) {
std::vector<common_adapter_lora_info> lora(lora_base);
int max_idx = lora.size();
// clear existing value
for (auto & entry : lora) {
entry.scale = 0.0f;
}
// set value
for (const auto & entry : data) {
int id = json_value(entry, "id", -1);
float scale = json_value(entry, "scale", 0.0f);
if (0 <= id && id < max_idx) {
lora[id].scale = scale;
} else {
throw std::runtime_error("invalid adapter id");
}
}
return lora;
}
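`parse_lora_request` expects a JSON array of `{id, scale}` objects, where `id` indexes into the adapters loaded at startup and anything unlisted is reset to scale 0. An illustrative request body (the exact endpoint is not shown here and is assumed):

```python
import json

# illustrative body accepted by parse_lora_request; ids index the --lora
# adapters passed at server startup, unlisted adapters fall back to scale 0.0
lora_request = [
    {"id": 0, "scale": 0.75},
    {"id": 1, "scale": 0.0},
]
print(json.dumps(lora_request))
```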

24
tools/server/webui/.gitignore vendored Normal file
View file

@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

View file

@ -0,0 +1,10 @@
**/.vscode
**/.github
**/.git
**/.svn
**/.hg
**/node_modules
**/dist
**/build
*.config.js

View file

@ -0,0 +1,26 @@
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
import tseslint from 'typescript-eslint'
export default tseslint.config(
{ ignores: ['dist'] },
{
extends: [js.configs.recommended, ...tseslint.configs.recommended],
files: ['**/*.{ts,tsx}'],
languageOptions: {
ecmaVersion: 2020,
globals: globals.browser,
},
plugins: {
'react-hooks': reactHooks,
'react-refresh': reactRefresh,
},
rules: {
...reactHooks.configs.recommended.rules,
'react-refresh/only-export-components': 'off',
'@typescript-eslint/no-unused-vars': 'off',
},
},
)

View file

@ -0,0 +1,16 @@
<!doctype html>
<html>
<head>
<meta charset="UTF-8" />
<meta
name="viewport"
content="width=device-width, initial-scale=1, maximum-scale=1"
/>
<meta name="color-scheme" content="light dark" />
<title>🦙 llama.cpp - chat</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/main.tsx"></script>
</body>
</html>

6255
tools/server/webui/package-lock.json generated Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,62 @@
{
"name": "webui",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "tsc -b && vite build",
"format": "eslint . && prettier --write .",
"lint": "eslint .",
"preview": "vite preview"
},
"dependencies": {
"@heroicons/react": "^2.2.0",
"@sec-ant/readable-stream": "^0.6.0",
"@tailwindcss/postcss": "^4.1.1",
"@tailwindcss/vite": "^4.1.1",
"@vscode/markdown-it-katex": "^1.1.1",
"autoprefixer": "^10.4.20",
"daisyui": "^5.0.12",
"dexie": "^4.0.11",
"highlight.js": "^11.10.0",
"katex": "^0.16.15",
"postcss": "^8.4.49",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"react-markdown": "^9.0.3",
"react-router": "^7.1.5",
"rehype-highlight": "^7.0.2",
"rehype-katex": "^7.0.1",
"remark-breaks": "^4.0.0",
"remark-gfm": "^4.0.0",
"remark-math": "^6.0.0",
"tailwindcss": "^4.1.1",
"textlinestream": "^1.1.1",
"vite-plugin-singlefile": "^2.0.3"
},
"devDependencies": {
"@eslint/js": "^9.17.0",
"@types/markdown-it": "^14.1.2",
"@types/node": "^22.13.1",
"@types/react": "^18.3.18",
"@types/react-dom": "^18.3.5",
"@vitejs/plugin-react": "^4.3.4",
"eslint": "^9.17.0",
"eslint-plugin-react-hooks": "^5.0.0",
"eslint-plugin-react-refresh": "^0.4.16",
"globals": "^15.14.0",
"prettier": "^3.4.2",
"sass-embedded": "^1.83.4",
"typescript": "~5.6.2",
"typescript-eslint": "^8.18.2",
"vite": "^6.0.5"
},
"prettier": {
"trailingComma": "es5",
"tabWidth": 2,
"semi": true,
"singleQuote": true,
"bracketSameLine": false
}
}

View file

@ -0,0 +1,5 @@
export default {
plugins: {
"@tailwindcss/postcss": {},
},
}

View file

@ -0,0 +1,33 @@
{
"demo": true,
"id": "conv-1734086746930",
"lastModified": 1734087548943,
"messages": [
{
"id": 1734086764521,
"role": "user",
"content": "this is a demo conversation, used in dev mode"
},
{
"id": 1734087548327,
"role": "assistant",
"content": "This is the formula:\n\n$\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}$\n\nGiven an input vector \\(\\mathbf{x} = [x_1, x_2, \\ldots, x_n]\\)\n\n\\[\ny_i = \\frac{e^{x_i}}{\\sum_{j=1}^n e^{x_j}}\n\\]\n\n$2x + y = z$\n\nCode block latex:\n```latex\n\\frac{e^{x_i}}{\\sum_{j=1}^{n}e^{x_j}}\n```\n\nTest dollar sign: $1234 $4567\n\nInvalid latex syntax: $E = mc^$ and $$E = mc^$$",
"timings": {
"prompt_n": 1,
"prompt_ms": 28.923,
"predicted_n": 25,
"predicted_ms": 573.016
}
},
{
"id": 1734087548328,
"role": "user",
"content": "this is a demo conversation, used in dev mode"
},
{
"id": 1734087548329,
"role": "assistant",
"content": "Code block:\n```js\nconsole.log('hello world')\n```\n```sh\nls -la /dev\n```"
}
]
}

View file

@ -0,0 +1,47 @@
import { HashRouter, Outlet, Route, Routes } from 'react-router';
import Header from './components/Header';
import Sidebar from './components/Sidebar';
import { AppContextProvider, useAppContext } from './utils/app.context';
import ChatScreen from './components/ChatScreen';
import SettingDialog from './components/SettingDialog';
function App() {
return (
<HashRouter>
<div className="flex flex-row drawer lg:drawer-open">
<AppContextProvider>
<Routes>
<Route element={<AppLayout />}>
<Route path="/chat/:convId" element={<ChatScreen />} />
<Route path="*" element={<ChatScreen />} />
</Route>
</Routes>
</AppContextProvider>
</div>
</HashRouter>
);
}
function AppLayout() {
const { showSettings, setShowSettings } = useAppContext();
return (
<>
<Sidebar />
<div
className="drawer-content grow flex flex-col h-screen w-screen mx-auto px-4 overflow-auto bg-base-100"
id="main-scroll"
>
<Header />
<Outlet />
</div>
{
<SettingDialog
show={showSettings}
onClose={() => setShowSettings(false)}
/>
}
</>
);
}
export default App;

View file

@ -0,0 +1,92 @@
import daisyuiThemes from 'daisyui/theme/object';
import { isNumeric } from './utils/misc';
export const isDev = import.meta.env.MODE === 'development';
// constants
export const BASE_URL = new URL('.', document.baseURI).href
.toString()
.replace(/\/$/, '');
export const CONFIG_DEFAULT = {
// Note: in order not to introduce breaking changes, please keep the same data type (number, string, etc) if you want to change the default value. Do not use null or undefined for default value.
// Do not use nested objects, keep it single level. Prefix the key if you need to group them.
apiKey: '',
systemMessage: 'You are a helpful assistant.',
showTokensPerSecond: false,
showThoughtInProgress: false,
excludeThoughtOnReq: true,
// make sure these default values are in sync with `common.h`
samplers: 'edkypmxt',
temperature: 0.8,
dynatemp_range: 0.0,
dynatemp_exponent: 1.0,
top_k: 40,
top_p: 0.95,
min_p: 0.05,
xtc_probability: 0.0,
xtc_threshold: 0.1,
typical_p: 1.0,
repeat_last_n: 64,
repeat_penalty: 1.0,
presence_penalty: 0.0,
frequency_penalty: 0.0,
dry_multiplier: 0.0,
dry_base: 1.75,
dry_allowed_length: 2,
dry_penalty_last_n: -1,
max_tokens: -1,
custom: '', // custom json-stringified object
// experimental features
pyIntepreterEnabled: false,
};
export const CONFIG_INFO: Record<string, string> = {
apiKey: 'Set the API Key if you are using --api-key option for the server.',
systemMessage: 'The starting message that defines how the model should behave.',
samplers:
'The order in which samplers are applied, in a simplified way. Default is "dkypmxt": dry->top_k->typ_p->top_p->min_p->xtc->temperature',
temperature:
'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.',
dynatemp_range:
'Addon for the temperature sampler. The added value to the range of dynamic temperature, which adjusts probabilities by entropy of tokens.',
dynatemp_exponent:
'Addon for the temperature sampler. Smoothes out the probability redistribution based on the most probable token.',
top_k: 'Keeps only k top tokens.',
top_p:
'Limits tokens to those that together have a cumulative probability of at least p',
min_p:
'Limits tokens based on the minimum probability for a token to be considered, relative to the probability of the most likely token.',
xtc_probability:
'XTC sampler cuts out top tokens; this parameter controls the chance of cutting tokens at all. 0 disables XTC.',
xtc_threshold:
'XTC sampler cuts out top tokens; this parameter controls the token probability that is required to cut that token.',
typical_p:
'Sorts and limits tokens based on the difference between log-probability and entropy.',
repeat_last_n: 'Last n tokens to consider for penalizing repetition',
repeat_penalty:
'Controls the repetition of token sequences in the generated text',
presence_penalty:
'Limits tokens based on whether they appear in the output or not.',
frequency_penalty:
'Limits tokens based on how often they appear in the output.',
dry_multiplier:
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling multiplier.',
dry_base:
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the DRY sampling base value.',
dry_allowed_length:
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets the allowed length for DRY sampling.',
dry_penalty_last_n:
'DRY sampling reduces repetition in generated text even across long contexts. This parameter sets DRY penalty for the last n tokens.',
max_tokens: 'The maximum number of tokens per output.',
custom: '', // custom json-stringified object
};
// config keys having numeric value (i.e. temperature, top_k, top_p, etc)
export const CONFIG_NUMERIC_KEYS = Object.entries(CONFIG_DEFAULT)
.filter((e) => isNumeric(e[1]))
.map((e) => e[0]);
// list of themes supported by daisyui
export const THEMES = ['light', 'dark']
// make sure light & dark are always at the beginning
.concat(
Object.keys(daisyuiThemes).filter((t) => t !== 'light' && t !== 'dark')
);

View file

@ -0,0 +1,195 @@
import { useEffect, useState } from 'react';
import { useAppContext } from '../utils/app.context';
import { OpenInNewTab, XCloseButton } from '../utils/common';
import { CanvasType } from '../utils/types';
import { PlayIcon, StopIcon } from '@heroicons/react/24/outline';
import StorageUtils from '../utils/storage';
const canInterrupt = typeof SharedArrayBuffer === 'function';
// adapted from https://pyodide.org/en/stable/usage/webworker.html
const WORKER_CODE = `
importScripts("https://cdn.jsdelivr.net/pyodide/v0.27.2/full/pyodide.js");
let stdOutAndErr = [];
let pyodideReadyPromise = loadPyodide({
stdout: (data) => stdOutAndErr.push(data),
stderr: (data) => stdOutAndErr.push(data),
});
let alreadySetBuff = false;
self.onmessage = async (event) => {
stdOutAndErr = [];
// make sure loading is done
const pyodide = await pyodideReadyPromise;
const { id, python, context, interruptBuffer } = event.data;
if (interruptBuffer && !alreadySetBuff) {
pyodide.setInterruptBuffer(interruptBuffer);
alreadySetBuff = true;
}
// Now load any packages we need, run the code, and send the result back.
await pyodide.loadPackagesFromImports(python);
// make a Python dictionary with the data from content
const dict = pyodide.globals.get("dict");
const globals = dict(Object.entries(context));
try {
self.postMessage({ id, running: true });
// Execute the python code in this context
const result = pyodide.runPython(python, { globals });
self.postMessage({ result, id, stdOutAndErr });
} catch (error) {
self.postMessage({ error: error.message, id });
}
interruptBuffer[0] = 0;
};
`;
let worker: Worker;
const interruptBuffer = canInterrupt
? new Uint8Array(new SharedArrayBuffer(1))
: null;
const startWorker = () => {
if (!worker) {
worker = new Worker(
URL.createObjectURL(new Blob([WORKER_CODE], { type: 'text/javascript' }))
);
}
};
if (StorageUtils.getConfig().pyIntepreterEnabled) {
startWorker();
}
const runCodeInWorker = (
pyCode: string,
callbackRunning: () => void
): {
donePromise: Promise<string>;
interrupt: () => void;
} => {
startWorker();
const id = Math.random() * 1e8;
const context = {};
if (interruptBuffer) {
interruptBuffer[0] = 0;
}
const donePromise = new Promise<string>((resolve) => {
worker.onmessage = (event) => {
const { error, stdOutAndErr, running } = event.data;
if (id !== event.data.id) return;
if (running) {
callbackRunning();
return;
} else if (error) {
resolve(error.toString());
} else {
resolve(stdOutAndErr.join('\n'));
}
};
worker.postMessage({ id, python: pyCode, context, interruptBuffer });
});
const interrupt = () => {
console.log('Interrupting...');
console.trace();
if (interruptBuffer) {
interruptBuffer[0] = 2;
}
};
return { donePromise, interrupt };
};
export default function CanvasPyInterpreter() {
const { canvasData, setCanvasData } = useAppContext();
const [code, setCode] = useState(canvasData?.content ?? ''); // copy to avoid direct mutation
const [running, setRunning] = useState(false);
const [output, setOutput] = useState('');
const [interruptFn, setInterruptFn] = useState<() => void>();
const [showStopBtn, setShowStopBtn] = useState(false);
const runCode = async (pycode: string) => {
interruptFn?.();
setRunning(true);
setOutput('Loading Pyodide...');
const { donePromise, interrupt } = runCodeInWorker(pycode, () => {
setOutput('Running...');
setShowStopBtn(canInterrupt);
});
setInterruptFn(() => interrupt);
const out = await donePromise;
setOutput(out);
setRunning(false);
setShowStopBtn(false);
};
// run code on mount
useEffect(() => {
setCode(canvasData?.content ?? '');
runCode(canvasData?.content ?? '');
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [canvasData?.content]);
if (canvasData?.type !== CanvasType.PY_INTERPRETER) {
return null;
}
return (
<div className="card bg-base-200 w-full h-full shadow-xl">
<div className="card-body">
<div className="flex justify-between items-center mb-4">
<span className="text-lg font-bold">Python Interpreter</span>
<XCloseButton
className="bg-base-100"
onClick={() => setCanvasData(null)}
/>
</div>
<div className="grid grid-rows-3 gap-4 h-full">
<textarea
className="textarea textarea-bordered w-full h-full font-mono"
value={code}
onChange={(e) => setCode(e.target.value)}
></textarea>
<div className="font-mono flex flex-col row-span-2">
<div className="flex items-center mb-2">
<button
className="btn btn-sm bg-base-100"
onClick={() => runCode(code)}
disabled={running}
>
<PlayIcon className="h-6 w-6" /> Run
</button>
{showStopBtn && (
<button
className="btn btn-sm bg-base-100 ml-2"
onClick={() => interruptFn?.()}
>
<StopIcon className="h-6 w-6" /> Stop
</button>
)}
<span className="grow text-right text-xs">
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/11762">
Report a bug
</OpenInNewTab>
</span>
</div>
<textarea
className="textarea textarea-bordered h-full dark-color"
value={output}
readOnly
></textarea>
</div>
</div>
</div>
</div>
);
}

View file

@@ -0,0 +1,296 @@
import { useMemo, useState } from 'react';
import { useAppContext } from '../utils/app.context';
import { Message, PendingMessage } from '../utils/types';
import { classNames } from '../utils/misc';
import MarkdownDisplay, { CopyButton } from './MarkdownDisplay';
import { ChevronLeftIcon, ChevronRightIcon } from '@heroicons/react/24/outline';
interface SplitMessage {
content: PendingMessage['content'];
thought?: string;
isThinking?: boolean;
}
export default function ChatMessage({
msg,
siblingLeafNodeIds,
siblingCurrIdx,
id,
onRegenerateMessage,
onEditMessage,
onChangeSibling,
isPending,
}: {
msg: Message | PendingMessage;
siblingLeafNodeIds: Message['id'][];
siblingCurrIdx: number;
id?: string;
onRegenerateMessage(msg: Message): void;
onEditMessage(msg: Message, content: string): void;
onChangeSibling(sibling: Message['id']): void;
isPending?: boolean;
}) {
const { viewingChat, config } = useAppContext();
const [editingContent, setEditingContent] = useState<string | null>(null);
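  // tokens per second = reported token count / elapsed milliseconds * 1000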
const timings = useMemo(
() =>
msg.timings
? {
...msg.timings,
prompt_per_second:
(msg.timings.prompt_n / msg.timings.prompt_ms) * 1000,
predicted_per_second:
(msg.timings.predicted_n / msg.timings.predicted_ms) * 1000,
}
: null,
[msg.timings]
);
const nextSibling = siblingLeafNodeIds[siblingCurrIdx + 1];
const prevSibling = siblingLeafNodeIds[siblingCurrIdx - 1];
  // for reasoning models, we split the message into content and thought
// TODO: implement this as remark/rehype plugin in the future
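  // e.g. '<think>plan the steps</think>Here is the answer' yields
  //      thought: 'plan the steps', content: 'Here is the answer'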
const { content, thought, isThinking }: SplitMessage = useMemo(() => {
if (msg.content === null || msg.role !== 'assistant') {
return { content: msg.content };
}
let actualContent = '';
let thought = '';
let isThinking = false;
let thinkSplit = msg.content.split('<think>', 2);
actualContent += thinkSplit[0];
while (thinkSplit[1] !== undefined) {
// <think> tag found
thinkSplit = thinkSplit[1].split('</think>', 2);
thought += thinkSplit[0];
isThinking = true;
if (thinkSplit[1] !== undefined) {
// </think> closing tag found
isThinking = false;
thinkSplit = thinkSplit[1].split('<think>', 2);
actualContent += thinkSplit[0];
}
}
return { content: actualContent, thought, isThinking };
}, [msg]);
if (!viewingChat) return null;
return (
<div className="group" id={id}>
<div
className={classNames({
chat: true,
'chat-start': msg.role !== 'user',
'chat-end': msg.role === 'user',
})}
>
<div
className={classNames({
'chat-bubble markdown': true,
'chat-bubble-base-300': msg.role !== 'user',
})}
>
{/* textarea for editing message */}
{editingContent !== null && (
<>
<textarea
dir="auto"
className="textarea textarea-bordered bg-base-100 text-base-content max-w-2xl w-[calc(90vw-8em)] h-24"
value={editingContent}
onChange={(e) => setEditingContent(e.target.value)}
></textarea>
<br />
<button
className="btn btn-ghost mt-2 mr-2"
onClick={() => setEditingContent(null)}
>
Cancel
</button>
<button
className="btn mt-2"
onClick={() => {
if (msg.content !== null) {
setEditingContent(null);
onEditMessage(msg as Message, editingContent);
}
}}
>
Submit
</button>
</>
)}
{/* not editing content, render message */}
{editingContent === null && (
<>
{content === null ? (
<>
{/* show loading dots for pending message */}
<span className="loading loading-dots loading-md"></span>
</>
) : (
<>
{/* render message as markdown */}
<div dir="auto">
{thought && (
<details
className="collapse bg-base-200 collapse-arrow mb-4"
open={isThinking && config.showThoughtInProgress}
>
<summary className="collapse-title">
{isPending && isThinking ? (
<span>
<span
v-if="isGenerating"
className="loading loading-spinner loading-md mr-2"
style={{ verticalAlign: 'middle' }}
></span>
<b>Thinking</b>
</span>
) : (
<b>Thought Process</b>
)}
</summary>
<div className="collapse-content">
<MarkdownDisplay
content={thought}
isGenerating={isPending}
/>
</div>
</details>
)}
{msg.extra && msg.extra.length > 0 && (
<details
className={classNames({
'collapse collapse-arrow mb-4 bg-base-200': true,
'bg-opacity-10': msg.role !== 'assistant',
})}
>
<summary className="collapse-title">
Extra content
</summary>
<div className="collapse-content">
{msg.extra.map(
(extra, i) =>
extra.type === 'textFile' ? (
<div key={extra.name}>
<b>{extra.name}</b>
<pre>{extra.content}</pre>
</div>
) : extra.type === 'context' ? (
<div key={i}>
<pre>{extra.content}</pre>
</div>
) : null // TODO: support other extra types
)}
</div>
</details>
)}
<MarkdownDisplay
content={content}
isGenerating={isPending}
/>
</div>
</>
)}
{/* render timings if enabled */}
{timings && config.showTokensPerSecond && (
<div className="dropdown dropdown-hover dropdown-top mt-2">
<div
tabIndex={0}
role="button"
className="cursor-pointer font-semibold text-sm opacity-60"
>
Speed: {timings.predicted_per_second.toFixed(1)} t/s
</div>
<div className="dropdown-content bg-base-100 z-10 w-64 p-2 shadow mt-4">
<b>Prompt</b>
<br />- Tokens: {timings.prompt_n}
<br />- Time: {timings.prompt_ms} ms
<br />- Speed: {timings.prompt_per_second.toFixed(1)} t/s
<br />
<b>Generation</b>
<br />- Tokens: {timings.predicted_n}
<br />- Time: {timings.predicted_ms} ms
<br />- Speed: {timings.predicted_per_second.toFixed(1)} t/s
<br />
</div>
</div>
)}
</>
)}
</div>
</div>
{/* actions for each message */}
{msg.content !== null && (
<div
className={classNames({
'flex items-center gap-2 mx-4 mt-2 mb-2': true,
'flex-row-reverse': msg.role === 'user',
})}
>
{siblingLeafNodeIds && siblingLeafNodeIds.length > 1 && (
<div className="flex gap-1 items-center opacity-60 text-sm">
<button
className={classNames({
'btn btn-sm btn-ghost p-1': true,
'opacity-20': !prevSibling,
})}
onClick={() => prevSibling && onChangeSibling(prevSibling)}
>
<ChevronLeftIcon className="h-4 w-4" />
</button>
<span>
{siblingCurrIdx + 1} / {siblingLeafNodeIds.length}
</span>
<button
className={classNames({
'btn btn-sm btn-ghost p-1': true,
'opacity-20': !nextSibling,
})}
onClick={() => nextSibling && onChangeSibling(nextSibling)}
>
<ChevronRightIcon className="h-4 w-4" />
</button>
</div>
)}
{/* user message */}
{msg.role === 'user' && (
<button
className="badge btn-mini show-on-hover"
onClick={() => setEditingContent(msg.content)}
disabled={msg.content === null}
>
Edit
</button>
)}
{/* assistant message */}
{msg.role === 'assistant' && (
<>
{!isPending && (
<button
className="badge btn-mini show-on-hover mr-2"
onClick={() => {
if (msg.content !== null) {
onRegenerateMessage(msg as Message);
}
}}
disabled={msg.content === null}
>
🔄 Regenerate
</button>
)}
</>
)}
<CopyButton
className="badge btn-mini show-on-hover mr-2"
content={msg.content}
/>
</div>
)}
</div>
);
}

View file

@@ -0,0 +1,296 @@
import { useEffect, useMemo, useState } from 'react';
import { CallbackGeneratedChunk, useAppContext } from '../utils/app.context';
import ChatMessage from './ChatMessage';
import { CanvasType, Message, PendingMessage } from '../utils/types';
import { classNames, cleanCurrentUrl, throttle } from '../utils/misc';
import CanvasPyInterpreter from './CanvasPyInterpreter';
import StorageUtils from '../utils/storage';
import { useVSCodeContext } from '../utils/llama-vscode';
import { useChatTextarea, ChatTextareaApi } from './useChatTextarea.ts';
/**
* A message display is a message node with additional information for rendering.
* For example, siblings of the message node are stored as their last node (aka leaf node).
*/
export interface MessageDisplay {
msg: Message | PendingMessage;
siblingLeafNodeIds: Message['id'][];
siblingCurrIdx: number;
isPending?: boolean;
}
/**
* If the current URL contains "?m=...", prefill the message input with the value.
* If the current URL contains "?q=...", prefill and SEND the message.
*/
const prefilledMsg = {
content() {
const url = new URL(window.location.href);
return url.searchParams.get('m') ?? url.searchParams.get('q') ?? '';
},
shouldSend() {
const url = new URL(window.location.href);
return url.searchParams.has('q');
},
clear() {
cleanCurrentUrl(['m', 'q']);
},
};
function getListMessageDisplay(
msgs: Readonly<Message[]>,
leafNodeId: Message['id']
): MessageDisplay[] {
const currNodes = StorageUtils.filterByLeafNodeId(msgs, leafNodeId, true);
const res: MessageDisplay[] = [];
const nodeMap = new Map<Message['id'], Message>();
for (const msg of msgs) {
nodeMap.set(msg.id, msg);
}
// find leaf node from a message node
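  // (keep following the last child at each level until a node with no children is reached)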
const findLeafNode = (msgId: Message['id']): Message['id'] => {
let currNode: Message | undefined = nodeMap.get(msgId);
while (currNode) {
if (currNode.children.length === 0) break;
currNode = nodeMap.get(currNode.children.at(-1) ?? -1);
}
return currNode?.id ?? -1;
};
// traverse the current nodes
for (const msg of currNodes) {
const parentNode = nodeMap.get(msg.parent ?? -1);
if (!parentNode) continue;
const siblings = parentNode.children;
if (msg.type !== 'root') {
res.push({
msg,
siblingLeafNodeIds: siblings.map(findLeafNode),
siblingCurrIdx: siblings.indexOf(msg.id),
});
}
}
return res;
}
const scrollToBottom = throttle(
(requiresNearBottom: boolean, delay: number = 80) => {
const mainScrollElem = document.getElementById('main-scroll');
if (!mainScrollElem) return;
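    // when requiresNearBottom is set, only auto-scroll if the user is already within ~50px of the bottom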
const spaceToBottom =
mainScrollElem.scrollHeight -
mainScrollElem.scrollTop -
mainScrollElem.clientHeight;
if (!requiresNearBottom || spaceToBottom < 50) {
setTimeout(
() => mainScrollElem.scrollTo({ top: mainScrollElem.scrollHeight }),
delay
);
}
},
80
);
export default function ChatScreen() {
const {
viewingChat,
sendMessage,
isGenerating,
stopGenerating,
pendingMessages,
canvasData,
replaceMessageAndGenerate,
} = useAppContext();
const textarea: ChatTextareaApi = useChatTextarea(prefilledMsg.content());
const { extraContext, clearExtraContext } = useVSCodeContext(textarea);
// TODO: improve this when we have "upload file" feature
const currExtra: Message['extra'] = extraContext ? [extraContext] : undefined;
// keep track of leaf node for rendering
const [currNodeId, setCurrNodeId] = useState<number>(-1);
const messages: MessageDisplay[] = useMemo(() => {
if (!viewingChat) return [];
else return getListMessageDisplay(viewingChat.messages, currNodeId);
}, [currNodeId, viewingChat]);
const currConvId = viewingChat?.conv.id ?? null;
const pendingMsg: PendingMessage | undefined =
pendingMessages[currConvId ?? ''];
useEffect(() => {
// reset to latest node when conversation changes
setCurrNodeId(-1);
// scroll to bottom when conversation changes
scrollToBottom(false, 1);
}, [currConvId]);
const onChunk: CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => {
if (currLeafNodeId) {
setCurrNodeId(currLeafNodeId);
}
scrollToBottom(true);
};
const sendNewMessage = async () => {
const lastInpMsg = textarea.value();
if (lastInpMsg.trim().length === 0 || isGenerating(currConvId ?? ''))
return;
textarea.setValue('');
scrollToBottom(false);
setCurrNodeId(-1);
// get the last message node
const lastMsgNodeId = messages.at(-1)?.msg.id ?? null;
if (
!(await sendMessage(
currConvId,
lastMsgNodeId,
lastInpMsg,
currExtra,
onChunk
))
) {
// restore the input message if failed
textarea.setValue(lastInpMsg);
}
    // clear the VSCode-provided extra context now that the send attempt has completed
clearExtraContext();
};
const handleEditMessage = async (msg: Message, content: string) => {
if (!viewingChat) return;
setCurrNodeId(msg.id);
scrollToBottom(false);
await replaceMessageAndGenerate(
viewingChat.conv.id,
msg.parent,
content,
msg.extra,
onChunk
);
setCurrNodeId(-1);
scrollToBottom(false);
};
const handleRegenerateMessage = async (msg: Message) => {
if (!viewingChat) return;
setCurrNodeId(msg.parent);
scrollToBottom(false);
await replaceMessageAndGenerate(
viewingChat.conv.id,
msg.parent,
null,
msg.extra,
onChunk
);
setCurrNodeId(-1);
scrollToBottom(false);
};
const hasCanvas = !!canvasData;
useEffect(() => {
if (prefilledMsg.shouldSend()) {
// send the prefilled message if needed
sendNewMessage();
} else {
// otherwise, focus on the input
textarea.focus();
}
prefilledMsg.clear();
// no need to keep track of sendNewMessage
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [textarea.ref]);
// due to some timing issues of StorageUtils.appendMsg(), we need to make sure the pendingMsg is not duplicated upon rendering (i.e. appears once in the saved conversation and once in the pendingMsg)
const pendingMsgDisplay: MessageDisplay[] =
pendingMsg && messages.at(-1)?.msg.id !== pendingMsg.id
? [
{
msg: pendingMsg,
siblingLeafNodeIds: [],
siblingCurrIdx: 0,
isPending: true,
},
]
: [];
return (
<div
className={classNames({
'grid lg:gap-8 grow transition-[300ms]': true,
'grid-cols-[1fr_0fr] lg:grid-cols-[1fr_1fr]': hasCanvas, // adapted for mobile
'grid-cols-[1fr_0fr]': !hasCanvas,
})}
>
<div
className={classNames({
'flex flex-col w-full max-w-[900px] mx-auto': true,
'hidden lg:flex': hasCanvas, // adapted for mobile
flex: !hasCanvas,
})}
>
{/* chat messages */}
<div id="messages-list" className="grow">
<div className="mt-auto flex justify-center">
{/* placeholder to shift the message to the bottom */}
{viewingChat ? '' : 'Send a message to start'}
</div>
{[...messages, ...pendingMsgDisplay].map((msg) => (
<ChatMessage
key={msg.msg.id}
msg={msg.msg}
siblingLeafNodeIds={msg.siblingLeafNodeIds}
siblingCurrIdx={msg.siblingCurrIdx}
onRegenerateMessage={handleRegenerateMessage}
onEditMessage={handleEditMessage}
onChangeSibling={setCurrNodeId}
/>
))}
</div>
{/* chat input */}
<div className="flex flex-row items-end pt-8 pb-6 sticky bottom-0 bg-base-100">
<textarea
// Default (mobile): Enable vertical resize, overflow auto for scrolling if needed
// Large screens (lg:): Disable manual resize, apply max-height for autosize limit
className="textarea textarea-bordered w-full resize-vertical lg:resize-none lg:max-h-48 lg:overflow-y-auto" // Adjust lg:max-h-48 as needed (e.g., lg:max-h-60)
placeholder="Type a message (Shift+Enter to add a new line)"
ref={textarea.ref}
onInput={textarea.onInput} // Hook's input handler (will only resize height on lg+ screens)
onKeyDown={(e) => {
if (e.nativeEvent.isComposing || e.keyCode === 229) return;
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendNewMessage();
}
}}
id="msg-input"
dir="auto"
// Set a base height of 2 rows for mobile views
// On lg+ screens, the hook will calculate and set the initial height anyway
rows={2}
></textarea>
{isGenerating(currConvId ?? '') ? (
<button
className="btn btn-neutral ml-2"
onClick={() => stopGenerating(currConvId ?? '')}
>
Stop
</button>
) : (
<button className="btn btn-primary ml-2" onClick={sendNewMessage}>
Send
</button>
)}
</div>
</div>
<div className="w-full sticky top-[7em] h-[calc(100vh-9em)]">
{canvasData?.type === CanvasType.PY_INTERPRETER && (
<CanvasPyInterpreter />
)}
</div>
</div>
);
}

View file

@ -0,0 +1,178 @@
import { useEffect, useState } from 'react';
import StorageUtils from '../utils/storage';
import { useAppContext } from '../utils/app.context';
import { classNames } from '../utils/misc';
import daisyuiThemes from 'daisyui/theme/object';
import { THEMES } from '../Config';
import { useNavigate } from 'react-router';
export default function Header() {
const navigate = useNavigate();
const [selectedTheme, setSelectedTheme] = useState(StorageUtils.getTheme());
const { setShowSettings } = useAppContext();
const setTheme = (theme: string) => {
StorageUtils.setTheme(theme);
setSelectedTheme(theme);
};
useEffect(() => {
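    // data-color-scheme is read by index.scss to load the matching highlight.js stylesheet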
document.body.setAttribute('data-theme', selectedTheme);
document.body.setAttribute(
'data-color-scheme',
daisyuiThemes[selectedTheme]?.['color-scheme'] ?? 'auto'
);
}, [selectedTheme]);
const { isGenerating, viewingChat } = useAppContext();
const isCurrConvGenerating = isGenerating(viewingChat?.conv.id ?? '');
const removeConversation = () => {
if (isCurrConvGenerating || !viewingChat) return;
const convId = viewingChat?.conv.id;
    if (window.confirm('Are you sure you want to delete this conversation?')) {
StorageUtils.remove(convId);
navigate('/');
}
};
const downloadConversation = () => {
if (isCurrConvGenerating || !viewingChat) return;
const convId = viewingChat?.conv.id;
const conversationJson = JSON.stringify(viewingChat, null, 2);
const blob = new Blob([conversationJson], { type: 'application/json' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = `conversation_${convId}.json`;
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
};
return (
<div className="flex flex-row items-center pt-6 pb-6 sticky top-0 z-10 bg-base-100">
{/* open sidebar button */}
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden">
<svg
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
fill="currentColor"
className="bi bi-list"
viewBox="0 0 16 16"
>
<path
fillRule="evenodd"
d="M2.5 12a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 0 1H3a.5.5 0 0 1-.5-.5m0-4a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 0 1H3a.5.5 0 0 1-.5-.5m0-4a.5.5 0 0 1 .5-.5h10a.5.5 0 0 1 0 1H3a.5.5 0 0 1-.5-.5"
/>
</svg>
</label>
<div className="grow text-2xl font-bold ml-2">llama.cpp</div>
{/* action buttons (top right) */}
<div className="flex items-center">
{viewingChat && (
<div className="dropdown dropdown-end">
{/* "..." button */}
<button
tabIndex={0}
role="button"
className="btn m-1"
disabled={isCurrConvGenerating}
>
<svg
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
fill="currentColor"
className="bi bi-three-dots-vertical"
viewBox="0 0 16 16"
>
<path d="M9.5 13a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0m0-5a1.5 1.5 0 1 1-3 0 1.5 1.5 0 0 1 3 0" />
</svg>
</button>
{/* dropdown menu */}
<ul
tabIndex={0}
className="dropdown-content menu bg-base-100 rounded-box z-[1] w-52 p-2 shadow"
>
<li onClick={downloadConversation}>
<a>Download</a>
</li>
<li className="text-error" onClick={removeConversation}>
<a>Delete</a>
</li>
</ul>
</div>
)}
<div className="tooltip tooltip-bottom" data-tip="Settings">
<button className="btn" onClick={() => setShowSettings(true)}>
{/* settings button */}
<svg
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
fill="currentColor"
className="bi bi-gear"
viewBox="0 0 16 16"
>
<path d="M8 4.754a3.246 3.246 0 1 0 0 6.492 3.246 3.246 0 0 0 0-6.492M5.754 8a2.246 2.246 0 1 1 4.492 0 2.246 2.246 0 0 1-4.492 0" />
<path d="M9.796 1.343c-.527-1.79-3.065-1.79-3.592 0l-.094.319a.873.873 0 0 1-1.255.52l-.292-.16c-1.64-.892-3.433.902-2.54 2.541l.159.292a.873.873 0 0 1-.52 1.255l-.319.094c-1.79.527-1.79 3.065 0 3.592l.319.094a.873.873 0 0 1 .52 1.255l-.16.292c-.892 1.64.901 3.434 2.541 2.54l.292-.159a.873.873 0 0 1 1.255.52l.094.319c.527 1.79 3.065 1.79 3.592 0l.094-.319a.873.873 0 0 1 1.255-.52l.292.16c1.64.893 3.434-.902 2.54-2.541l-.159-.292a.873.873 0 0 1 .52-1.255l.319-.094c1.79-.527 1.79-3.065 0-3.592l-.319-.094a.873.873 0 0 1-.52-1.255l.16-.292c.893-1.64-.902-3.433-2.541-2.54l-.292.159a.873.873 0 0 1-1.255-.52zm-2.633.283c.246-.835 1.428-.835 1.674 0l.094.319a1.873 1.873 0 0 0 2.693 1.115l.291-.16c.764-.415 1.6.42 1.184 1.185l-.159.292a1.873 1.873 0 0 0 1.116 2.692l.318.094c.835.246.835 1.428 0 1.674l-.319.094a1.873 1.873 0 0 0-1.115 2.693l.16.291c.415.764-.42 1.6-1.185 1.184l-.291-.159a1.873 1.873 0 0 0-2.693 1.116l-.094.318c-.246.835-1.428.835-1.674 0l-.094-.319a1.873 1.873 0 0 0-2.692-1.115l-.292.16c-.764.415-1.6-.42-1.184-1.185l.159-.291A1.873 1.873 0 0 0 1.945 8.93l-.319-.094c-.835-.246-.835-1.428 0-1.674l.319-.094A1.873 1.873 0 0 0 3.06 4.377l-.16-.292c-.415-.764.42-1.6 1.185-1.184l.292.159a1.873 1.873 0 0 0 2.692-1.115z" />
</svg>
</button>
</div>
{/* theme controller is copied from https://daisyui.com/components/theme-controller/ */}
<div className="tooltip tooltip-bottom" data-tip="Themes">
<div className="dropdown dropdown-end dropdown-bottom">
<div tabIndex={0} role="button" className="btn m-1">
<svg
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
fill="currentColor"
className="bi bi-palette2"
viewBox="0 0 16 16"
>
<path d="M0 .5A.5.5 0 0 1 .5 0h5a.5.5 0 0 1 .5.5v5.277l4.147-4.131a.5.5 0 0 1 .707 0l3.535 3.536a.5.5 0 0 1 0 .708L10.261 10H15.5a.5.5 0 0 1 .5.5v5a.5.5 0 0 1-.5.5H3a3 3 0 0 1-2.121-.879A3 3 0 0 1 0 13.044m6-.21 7.328-7.3-2.829-2.828L6 7.188zM4.5 13a1.5 1.5 0 1 0-3 0 1.5 1.5 0 0 0 3 0M15 15v-4H9.258l-4.015 4zM0 .5v12.495zm0 12.495V13z" />
</svg>
</div>
<ul
tabIndex={0}
className="dropdown-content bg-base-300 rounded-box z-[1] w-52 p-2 shadow-2xl h-80 overflow-y-auto"
>
<li>
<button
className={classNames({
'btn btn-sm btn-block btn-ghost justify-start': true,
'btn-active': selectedTheme === 'auto',
})}
onClick={() => setTheme('auto')}
>
auto
</button>
</li>
{THEMES.map((theme) => (
<li key={theme}>
<input
type="radio"
name="theme-dropdown"
className="theme-controller btn btn-sm btn-block btn-ghost justify-start"
aria-label={theme}
value={theme}
checked={selectedTheme === theme}
onChange={(e) => e.target.checked && setTheme(theme)}
/>
</li>
))}
</ul>
</div>
</div>
</div>
</div>
);
}

View file

@@ -0,0 +1,310 @@
import React, { useMemo, useState } from 'react';
import Markdown, { ExtraProps } from 'react-markdown';
import remarkGfm from 'remark-gfm';
import rehypeHightlight from 'rehype-highlight';
import rehypeKatex from 'rehype-katex';
import remarkMath from 'remark-math';
import remarkBreaks from 'remark-breaks';
import 'katex/dist/katex.min.css';
import { classNames, copyStr } from '../utils/misc';
import { ElementContent, Root } from 'hast';
import { visit } from 'unist-util-visit';
import { useAppContext } from '../utils/app.context';
import { CanvasType } from '../utils/types';
export default function MarkdownDisplay({
content,
isGenerating,
}: {
content: string;
isGenerating?: boolean;
}) {
const preprocessedContent = useMemo(
() => preprocessLaTeX(content),
[content]
);
return (
<Markdown
remarkPlugins={[remarkGfm, remarkMath, remarkBreaks]}
rehypePlugins={[rehypeHightlight, rehypeKatex, rehypeCustomCopyButton]}
components={{
button: (props) => (
<CodeBlockButtons
{...props}
isGenerating={isGenerating}
origContent={preprocessedContent}
/>
),
        // note: do not use "pre", "p" or other basic HTML elements here; doing so causes the node to re-render while the message is being generated (this appears to be a react-markdown bug without a known fix)
}}
>
{preprocessedContent}
</Markdown>
);
}
const CodeBlockButtons: React.ElementType<
React.ClassAttributes<HTMLButtonElement> &
React.HTMLAttributes<HTMLButtonElement> &
ExtraProps & { origContent: string; isGenerating?: boolean }
> = ({ node, origContent, isGenerating }) => {
const { config } = useAppContext();
const startOffset = node?.position?.start.offset ?? 0;
const endOffset = node?.position?.end.offset ?? 0;
const copiedContent = useMemo(
() =>
origContent
.substring(startOffset, endOffset)
.replace(/^```[^\n]+\n/g, '')
.replace(/```$/g, ''),
[origContent, startOffset, endOffset]
);
const codeLanguage = useMemo(
() =>
origContent
.substring(startOffset, startOffset + 10)
.match(/^```([^\n]+)\n/)?.[1] ?? '',
[origContent, startOffset]
);
const canRunCode =
!isGenerating &&
config.pyIntepreterEnabled &&
codeLanguage.startsWith('py');
return (
<div
className={classNames({
'text-right sticky top-[7em] mb-2 mr-2 h-0': true,
'display-none': !node?.position,
})}
>
<CopyButton className="badge btn-mini" content={copiedContent} />
{canRunCode && (
<RunPyCodeButton
className="badge btn-mini ml-2"
content={copiedContent}
/>
)}
</div>
);
};
export const CopyButton = ({
content,
className,
}: {
content: string;
className?: string;
}) => {
const [copied, setCopied] = useState(false);
return (
<button
className={className}
onClick={() => {
copyStr(content);
setCopied(true);
}}
onMouseLeave={() => setCopied(false)}
>
{copied ? 'Copied!' : '📋 Copy'}
</button>
);
};
export const RunPyCodeButton = ({
content,
className,
}: {
content: string;
className?: string;
}) => {
const { setCanvasData } = useAppContext();
return (
<>
<button
className={className}
onClick={() =>
setCanvasData({
type: CanvasType.PY_INTERPRETER,
content,
})
}
>
Run
</button>
</>
);
};
/**
* This injects the "button" element before each "pre" element.
* The actual button will be replaced with a react component in the MarkdownDisplay.
* We don't replace "pre" node directly because it will cause the node to re-render, which causes this bug: https://github.com/ggerganov/llama.cpp/issues/9608
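 * i.e. in the hast tree, <pre>...</pre> becomes <div><button></button><pre visited="true">...</pre></div>.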
*/
function rehypeCustomCopyButton() {
return function (tree: Root) {
visit(tree, 'element', function (node) {
if (node.tagName === 'pre' && !node.properties.visited) {
const preNode = { ...node };
// replace current node
preNode.properties.visited = 'true';
node.tagName = 'div';
node.properties = {};
// add node for button
const btnNode: ElementContent = {
type: 'element',
tagName: 'button',
properties: {},
children: [],
position: node.position,
};
node.children = [btnNode, preNode];
}
});
};
}
/**
* The part below is copied and adapted from:
* https://github.com/danny-avila/LibreChat/blob/main/client/src/utils/latex.ts
* (MIT License)
*/
// Regex to check if the processed content contains any potential LaTeX patterns
const containsLatexRegex =
/\\\(.*?\\\)|\\\[.*?\\\]|\$.*?\$|\\begin\{equation\}.*?\\end\{equation\}/;
// Regex for inline and block LaTeX expressions
const inlineLatex = new RegExp(/\\\((.+?)\\\)/, 'g');
const blockLatex = new RegExp(/\\\[(.*?[^\\])\\\]/, 'gs');
// Function to restore code blocks
const restoreCodeBlocks = (content: string, codeBlocks: string[]) => {
return content.replace(
/<<CODE_BLOCK_(\d+)>>/g,
(_, index) => codeBlocks[index]
);
};
// Regex to identify code blocks and inline code
const codeBlockRegex = /(```[\s\S]*?```|`.*?`)/g;
export const processLaTeX = (_content: string) => {
let content = _content;
// Temporarily replace code blocks and inline code with placeholders
const codeBlocks: string[] = [];
let index = 0;
content = content.replace(codeBlockRegex, (match) => {
codeBlocks[index] = match;
return `<<CODE_BLOCK_${index++}>>`;
});
// Escape dollar signs followed by a digit or space and digit
let processedContent = content.replace(/(\$)(?=\s?\d)/g, '\\$');
// If no LaTeX patterns are found, restore code blocks and return the processed content
if (!containsLatexRegex.test(processedContent)) {
return restoreCodeBlocks(processedContent, codeBlocks);
}
// Convert LaTeX expressions to a markdown compatible format
processedContent = processedContent
.replace(inlineLatex, (_: string, equation: string) => `$${equation}$`) // Convert inline LaTeX
.replace(blockLatex, (_: string, equation: string) => `$$${equation}$$`); // Convert block LaTeX
// Restore code blocks
return restoreCodeBlocks(processedContent, codeBlocks);
};
/**
* Preprocesses LaTeX content by replacing delimiters and escaping certain characters.
*
* @param content The input string containing LaTeX expressions.
* @returns The processed string with replaced delimiters and escaped characters.
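 * e.g. "Total: $5 and the area is \(x^2\)." becomes "Total: \$5 and the area is $x^2$."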
*/
export function preprocessLaTeX(content: string): string {
// Step 1: Protect code blocks
const codeBlocks: string[] = [];
content = content.replace(/(```[\s\S]*?```|`[^`\n]+`)/g, (_, code) => {
codeBlocks.push(code);
return `<<CODE_BLOCK_${codeBlocks.length - 1}>>`;
});
// Step 2: Protect existing LaTeX expressions
const latexExpressions: string[] = [];
// Protect block math ($$...$$), \[...\], and \(...\) as before.
content = content.replace(
/(\$\$[\s\S]*?\$\$|\\\[[\s\S]*?\\\]|\\\(.*?\\\))/g,
(match) => {
latexExpressions.push(match);
return `<<LATEX_${latexExpressions.length - 1}>>`;
}
);
// Protect inline math ($...$) only if it does NOT match a currency pattern.
// We assume a currency pattern is one where the inner content is purely numeric (with optional decimals).
content = content.replace(/\$([^$]+)\$/g, (match, inner) => {
if (/^\s*\d+(?:\.\d+)?\s*$/.test(inner)) {
// This looks like a currency value (e.g. "$123" or "$12.34"),
// so don't protect it.
return match;
} else {
// Otherwise, treat it as a LaTeX expression.
latexExpressions.push(match);
return `<<LATEX_${latexExpressions.length - 1}>>`;
}
});
// Step 3: Escape dollar signs that are likely currency indicators.
// (Now that inline math is protected, this will only escape dollars not already protected)
content = content.replace(/\$(?=\d)/g, '\\$');
// Step 4: Restore LaTeX expressions
content = content.replace(
/<<LATEX_(\d+)>>/g,
(_, index) => latexExpressions[parseInt(index)]
);
// Step 5: Restore code blocks
content = content.replace(
/<<CODE_BLOCK_(\d+)>>/g,
(_, index) => codeBlocks[parseInt(index)]
);
// Step 6: Apply additional escaping functions
content = escapeBrackets(content);
content = escapeMhchem(content);
return content;
}
export function escapeBrackets(text: string): string {
const pattern =
/(```[\S\s]*?```|`.*?`)|\\\[([\S\s]*?[^\\])\\]|\\\((.*?)\\\)/g;
return text.replace(
pattern,
(
match: string,
codeBlock: string | undefined,
squareBracket: string | undefined,
roundBracket: string | undefined
): string => {
if (codeBlock != null) {
return codeBlock;
} else if (squareBracket != null) {
return `$$${squareBracket}$$`;
} else if (roundBracket != null) {
return `$${roundBracket}$`;
}
return match;
}
);
}
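// e.g. '$\ce{H2O}$' becomes '$\\ce{H2O}$', and '$\pu{...}$' is doubled the same way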
export function escapeMhchem(text: string) {
return text.replaceAll('$\\ce{', '$\\\\ce{').replaceAll('$\\pu{', '$\\\\pu{');
}

View file

@@ -0,0 +1,536 @@
import { useState } from 'react';
import { useAppContext } from '../utils/app.context';
import { CONFIG_DEFAULT, CONFIG_INFO } from '../Config';
import { isDev } from '../Config';
import StorageUtils from '../utils/storage';
import { classNames, isBoolean, isNumeric, isString } from '../utils/misc';
import {
BeakerIcon,
ChatBubbleOvalLeftEllipsisIcon,
Cog6ToothIcon,
FunnelIcon,
HandRaisedIcon,
SquaresPlusIcon,
} from '@heroicons/react/24/outline';
import { OpenInNewTab } from '../utils/common';
type SettKey = keyof typeof CONFIG_DEFAULT;
const BASIC_KEYS: SettKey[] = [
'temperature',
'top_k',
'top_p',
'min_p',
'max_tokens',
];
const SAMPLER_KEYS: SettKey[] = [
'dynatemp_range',
'dynatemp_exponent',
'typical_p',
'xtc_probability',
'xtc_threshold',
];
const PENALTY_KEYS: SettKey[] = [
'repeat_last_n',
'repeat_penalty',
'presence_penalty',
'frequency_penalty',
'dry_multiplier',
'dry_base',
'dry_allowed_length',
'dry_penalty_last_n',
];
enum SettingInputType {
SHORT_INPUT,
LONG_INPUT,
CHECKBOX,
CUSTOM,
}
interface SettingFieldInput {
type: Exclude<SettingInputType, SettingInputType.CUSTOM>;
label: string | React.ReactElement;
help?: string | React.ReactElement;
key: SettKey;
}
interface SettingFieldCustom {
type: SettingInputType.CUSTOM;
key: SettKey;
component:
| string
| React.FC<{
value: string | boolean | number;
onChange: (value: string) => void;
}>;
}
interface SettingSection {
title: React.ReactElement;
fields: (SettingFieldInput | SettingFieldCustom)[];
}
const ICON_CLASSNAME = 'w-4 h-4 mr-1 inline';
const SETTING_SECTIONS: SettingSection[] = [
{
title: (
<>
<Cog6ToothIcon className={ICON_CLASSNAME} />
General
</>
),
fields: [
{
type: SettingInputType.SHORT_INPUT,
label: 'API Key',
key: 'apiKey',
},
{
type: SettingInputType.LONG_INPUT,
label: 'System Message (will be disabled if left empty)',
key: 'systemMessage',
},
...BASIC_KEYS.map(
(key) =>
({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
}) as SettingFieldInput
),
],
},
{
title: (
<>
<FunnelIcon className={ICON_CLASSNAME} />
Samplers
</>
),
fields: [
{
type: SettingInputType.SHORT_INPUT,
label: 'Samplers queue',
key: 'samplers',
},
...SAMPLER_KEYS.map(
(key) =>
({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
}) as SettingFieldInput
),
],
},
{
title: (
<>
<HandRaisedIcon className={ICON_CLASSNAME} />
Penalties
</>
),
fields: PENALTY_KEYS.map((key) => ({
type: SettingInputType.SHORT_INPUT,
label: key,
key,
})),
},
{
title: (
<>
<ChatBubbleOvalLeftEllipsisIcon className={ICON_CLASSNAME} />
Reasoning
</>
),
fields: [
{
type: SettingInputType.CHECKBOX,
label: 'Expand thought process by default when generating messages',
key: 'showThoughtInProgress',
},
{
type: SettingInputType.CHECKBOX,
label:
'Exclude thought process when sending requests to API (Recommended for DeepSeek-R1)',
key: 'excludeThoughtOnReq',
},
],
},
{
title: (
<>
<SquaresPlusIcon className={ICON_CLASSNAME} />
Advanced
</>
),
fields: [
{
type: SettingInputType.CUSTOM,
key: 'custom', // dummy key, won't be used
component: () => {
const debugImportDemoConv = async () => {
const res = await fetch('/demo-conversation.json');
const demoConv = await res.json();
StorageUtils.remove(demoConv.id);
for (const msg of demoConv.messages) {
StorageUtils.appendMsg(demoConv.id, msg);
}
};
return (
<button className="btn" onClick={debugImportDemoConv}>
(debug) Import demo conversation
</button>
);
},
},
{
type: SettingInputType.CHECKBOX,
label: 'Show tokens per second',
key: 'showTokensPerSecond',
},
{
type: SettingInputType.LONG_INPUT,
label: (
<>
Custom JSON config (For more info, refer to{' '}
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/blob/master/tools/server/README.md">
server documentation
</OpenInNewTab>
)
</>
),
key: 'custom',
},
],
},
{
title: (
<>
<BeakerIcon className={ICON_CLASSNAME} />
Experimental
</>
),
fields: [
{
type: SettingInputType.CUSTOM,
key: 'custom', // dummy key, won't be used
component: () => (
<>
<p className="mb-8">
Experimental features are not guaranteed to work correctly.
<br />
<br />
If you encounter any problems, create a{' '}
<OpenInNewTab href="https://github.com/ggerganov/llama.cpp/issues/new?template=019-bug-misc.yml">
Bug (misc.)
</OpenInNewTab>{' '}
            report on GitHub. Please also specify <b>webui/experimental</b> in
            the report title and include screenshots.
            <br />
            <br />
            Some features may require packages downloaded from a CDN, so they
            need an internet connection.
</p>
</>
),
},
{
type: SettingInputType.CHECKBOX,
label: (
<>
<b>Enable Python interpreter</b>
<br />
<small className="text-xs">
This feature uses{' '}
<OpenInNewTab href="https://pyodide.org">pyodide</OpenInNewTab>,
            downloaded from a CDN. To use this feature, ask the LLM to generate
Python code inside a Markdown code block. You will see a "Run"
button on the code block, near the "Copy" button.
</small>
</>
),
key: 'pyIntepreterEnabled',
},
],
},
];
export default function SettingDialog({
show,
onClose,
}: {
show: boolean;
onClose: () => void;
}) {
const { config, saveConfig } = useAppContext();
const [sectionIdx, setSectionIdx] = useState(0);
// clone the config object to prevent direct mutation
const [localConfig, setLocalConfig] = useState<typeof CONFIG_DEFAULT>(
JSON.parse(JSON.stringify(config))
);
const resetConfig = () => {
if (window.confirm('Are you sure you want to reset all settings?')) {
setLocalConfig(CONFIG_DEFAULT);
}
};
const handleSave = () => {
// copy the local config to prevent direct mutation
const newConfig: typeof CONFIG_DEFAULT = JSON.parse(
JSON.stringify(localConfig)
);
// validate the config
for (const key in newConfig) {
const value = newConfig[key as SettKey];
const mustBeBoolean = isBoolean(CONFIG_DEFAULT[key as SettKey]);
const mustBeString = isString(CONFIG_DEFAULT[key as SettKey]);
const mustBeNumeric = isNumeric(CONFIG_DEFAULT[key as SettKey]);
if (mustBeString) {
if (!isString(value)) {
alert(`Value for ${key} must be string`);
return;
}
} else if (mustBeNumeric) {
const trimmedValue = value.toString().trim();
const numVal = Number(trimmedValue);
if (isNaN(numVal) || !isNumeric(numVal) || trimmedValue.length === 0) {
alert(`Value for ${key} must be numeric`);
return;
}
// force conversion to number
// @ts-expect-error this is safe
newConfig[key] = numVal;
} else if (mustBeBoolean) {
if (!isBoolean(value)) {
alert(`Value for ${key} must be boolean`);
return;
}
} else {
console.error(`Unknown default type for key ${key}`);
}
}
if (isDev) console.log('Saving config', newConfig);
saveConfig(newConfig);
onClose();
};
const onChange = (key: SettKey) => (value: string | boolean) => {
// note: we do not perform validation here, because we may get incomplete value as user is still typing it
setLocalConfig({ ...localConfig, [key]: value });
};
return (
<dialog className={classNames({ modal: true, 'modal-open': show })}>
<div className="modal-box w-11/12 max-w-3xl">
<h3 className="text-lg font-bold mb-6">Settings</h3>
<div className="flex flex-col md:flex-row h-[calc(90vh-12rem)]">
{/* Left panel, showing sections - Desktop version */}
<div className="hidden md:flex flex-col items-stretch pr-4 mr-4 border-r-2 border-base-200">
{SETTING_SECTIONS.map((section, idx) => (
<div
key={idx}
className={classNames({
'btn btn-ghost justify-start font-normal w-44 mb-1': true,
'btn-active': sectionIdx === idx,
})}
onClick={() => setSectionIdx(idx)}
dir="auto"
>
{section.title}
</div>
))}
</div>
{/* Left panel, showing sections - Mobile version */}
<div className="md:hidden flex flex-row gap-2 mb-4">
<details className="dropdown">
<summary className="btn bt-sm w-full m-1">
{SETTING_SECTIONS[sectionIdx].title}
</summary>
<ul className="menu dropdown-content bg-base-100 rounded-box z-[1] w-52 p-2 shadow">
{SETTING_SECTIONS.map((section, idx) => (
<div
key={idx}
className={classNames({
'btn btn-ghost justify-start font-normal': true,
'btn-active': sectionIdx === idx,
})}
onClick={() => setSectionIdx(idx)}
dir="auto"
>
{section.title}
</div>
))}
</ul>
</details>
</div>
{/* Right panel, showing setting fields */}
<div className="grow overflow-y-auto px-4">
{SETTING_SECTIONS[sectionIdx].fields.map((field, idx) => {
const key = `${sectionIdx}-${idx}`;
if (field.type === SettingInputType.SHORT_INPUT) {
return (
<SettingsModalShortInput
key={key}
configKey={field.key}
value={localConfig[field.key]}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.LONG_INPUT) {
return (
<SettingsModalLongInput
key={key}
configKey={field.key}
value={localConfig[field.key].toString()}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.CHECKBOX) {
return (
<SettingsModalCheckbox
key={key}
configKey={field.key}
value={!!localConfig[field.key]}
onChange={onChange(field.key)}
label={field.label as string}
/>
);
} else if (field.type === SettingInputType.CUSTOM) {
return (
<div key={key} className="mb-2">
{typeof field.component === 'string'
? field.component
: field.component({
value: localConfig[field.key],
onChange: onChange(field.key),
})}
</div>
);
}
})}
<p className="opacity-40 mb-6 text-sm mt-8">
              Settings are saved in the browser's localStorage
</p>
</div>
</div>
<div className="modal-action">
<button className="btn" onClick={resetConfig}>
Reset to default
</button>
<button className="btn" onClick={onClose}>
Close
</button>
<button className="btn btn-primary" onClick={handleSave}>
Save
</button>
</div>
</div>
</dialog>
);
}
function SettingsModalLongInput({
configKey,
value,
onChange,
label,
}: {
configKey: SettKey;
value: string;
onChange: (value: string) => void;
label?: string;
}) {
return (
<label className="form-control mb-2">
<div className="label inline">{label || configKey}</div>
<textarea
className="textarea textarea-bordered h-24"
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
value={value}
onChange={(e) => onChange(e.target.value)}
/>
</label>
);
}
function SettingsModalShortInput({
configKey,
value,
onChange,
label,
}: {
configKey: SettKey;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
value: any;
onChange: (value: string) => void;
label?: string;
}) {
const helpMsg = CONFIG_INFO[configKey];
return (
<>
{/* on mobile, we simply show the help message here */}
{helpMsg && (
<div className="block md:hidden mb-1">
<b>{label || configKey}</b>
<br />
<p className="text-xs">{helpMsg}</p>
</div>
)}
<label className="input input-bordered join-item grow flex items-center gap-2 mb-2">
<div className="dropdown dropdown-hover">
<div tabIndex={0} role="button" className="font-bold hidden md:block">
{label || configKey}
</div>
{helpMsg && (
<div className="dropdown-content menu bg-base-100 rounded-box z-10 w-64 p-2 shadow mt-4">
{helpMsg}
</div>
)}
</div>
<input
type="text"
className="grow"
placeholder={`Default: ${CONFIG_DEFAULT[configKey] || 'none'}`}
value={value}
onChange={(e) => onChange(e.target.value)}
/>
</label>
</>
);
}
function SettingsModalCheckbox({
configKey,
value,
onChange,
label,
}: {
configKey: SettKey;
value: boolean;
onChange: (value: boolean) => void;
label: string;
}) {
return (
<div className="flex flex-row items-center mb-2">
<input
type="checkbox"
className="toggle"
checked={value}
onChange={(e) => onChange(e.target.checked)}
/>
<span className="ml-4">{label || configKey}</span>
</div>
);
}

View file

@@ -0,0 +1,96 @@
import { useEffect, useState } from 'react';
import { classNames } from '../utils/misc';
import { Conversation } from '../utils/types';
import StorageUtils from '../utils/storage';
import { useNavigate, useParams } from 'react-router';
export default function Sidebar() {
const params = useParams();
const navigate = useNavigate();
const [conversations, setConversations] = useState<Conversation[]>([]);
const [currConv, setCurrConv] = useState<Conversation | null>(null);
useEffect(() => {
StorageUtils.getOneConversation(params.convId ?? '').then(setCurrConv);
}, [params.convId]);
useEffect(() => {
const handleConversationChange = async () => {
setConversations(await StorageUtils.getAllConversations());
};
StorageUtils.onConversationChanged(handleConversationChange);
handleConversationChange();
return () => {
StorageUtils.offConversationChanged(handleConversationChange);
};
}, []);
return (
<>
<input
id="toggle-drawer"
type="checkbox"
className="drawer-toggle"
defaultChecked
/>
<div className="drawer-side h-screen lg:h-screen z-50 lg:max-w-64">
<label
htmlFor="toggle-drawer"
aria-label="close sidebar"
className="drawer-overlay"
></label>
<div className="flex flex-col bg-base-200 min-h-full max-w-64 py-4 px-4">
<div className="flex flex-row items-center justify-between mb-4 mt-4">
<h2 className="font-bold ml-4">Conversations</h2>
{/* close sidebar button */}
<label htmlFor="toggle-drawer" className="btn btn-ghost lg:hidden">
<svg
xmlns="http://www.w3.org/2000/svg"
width="16"
height="16"
fill="currentColor"
className="bi bi-arrow-bar-left"
viewBox="0 0 16 16"
>
<path
fillRule="evenodd"
d="M12.5 15a.5.5 0 0 1-.5-.5v-13a.5.5 0 0 1 1 0v13a.5.5 0 0 1-.5.5M10 8a.5.5 0 0 1-.5.5H3.707l2.147 2.146a.5.5 0 0 1-.708.708l-3-3a.5.5 0 0 1 0-.708l3-3a.5.5 0 1 1 .708.708L3.707 7.5H9.5a.5.5 0 0 1 .5.5"
/>
</svg>
</label>
</div>
{/* list of conversations */}
<div
className={classNames({
'btn btn-ghost justify-start': true,
'btn-active': !currConv,
})}
onClick={() => navigate('/')}
>
+ New conversation
</div>
{conversations.map((conv) => (
<div
key={conv.id}
className={classNames({
'btn btn-ghost justify-start font-normal': true,
'btn-active': conv.id === currConv?.id,
})}
onClick={() => navigate(`/chat/${conv.id}`)}
dir="auto"
>
<span className="truncate">{conv.name}</span>
</div>
))}
<div className="text-center text-xs opacity-40 mt-auto mx-4">
          Conversations are saved to the browser's IndexedDB
</div>
</div>
</div>
</>
);
}

View file

@@ -0,0 +1,96 @@
import { useEffect, useRef, useState, useCallback } from 'react';
// Media Query for detecting "large" screens (matching Tailwind's lg: breakpoint)
const LARGE_SCREEN_MQ = '(min-width: 1024px)';
// Calculates and sets the textarea height based on its scrollHeight
const adjustTextareaHeight = (textarea: HTMLTextAreaElement | null) => {
if (!textarea) return;
// Only perform auto-sizing on large screens
if (!window.matchMedia(LARGE_SCREEN_MQ).matches) {
// On small screens, reset inline height and max-height styles.
// This allows CSS (e.g., `rows` attribute or classes) to control the height,
// and enables manual resizing if `resize-vertical` is set.
textarea.style.height = ''; // Use 'auto' or '' to reset
textarea.style.maxHeight = '';
return; // Do not adjust height programmatically on small screens
}
const computedStyle = window.getComputedStyle(textarea);
// Get the max-height specified by CSS (e.g., from `lg:max-h-48`)
const currentMaxHeight = computedStyle.maxHeight;
// Temporarily remove max-height to allow scrollHeight to be calculated correctly
textarea.style.maxHeight = 'none';
// Reset height to 'auto' to measure the actual scrollHeight needed
textarea.style.height = 'auto';
// Set the height to the calculated scrollHeight
textarea.style.height = `${textarea.scrollHeight}px`;
// Re-apply the original max-height from CSS to enforce the limit
textarea.style.maxHeight = currentMaxHeight;
};
// Interface describing the API returned by the hook
export interface ChatTextareaApi {
value: () => string;
setValue: (value: string) => void;
focus: () => void;
ref: React.RefObject<HTMLTextAreaElement>;
onInput: (event: React.FormEvent<HTMLTextAreaElement>) => void; // Input handler
}
// This is a workaround to prevent the textarea from re-rendering when the inner content changes
// See https://github.com/ggml-org/llama.cpp/pull/12299
// combined now with auto-sizing logic.
export function useChatTextarea(initValue: string): ChatTextareaApi {
const [savedInitValue, setSavedInitValue] = useState<string>(initValue);
const textareaRef = useRef<HTMLTextAreaElement>(null);
// Effect to set initial value and height on mount or when initValue changes
useEffect(() => {
const textarea = textareaRef.current;
if (textarea) {
if (typeof savedInitValue === 'string' && savedInitValue.length > 0) {
textarea.value = savedInitValue;
// Call adjustTextareaHeight - it will check screen size internally
setTimeout(() => adjustTextareaHeight(textarea), 0);
setSavedInitValue(''); // Reset after applying
} else {
// Adjust height even if there's no initial value (for initial render)
setTimeout(() => adjustTextareaHeight(textarea), 0);
}
}
}, [textareaRef, savedInitValue]); // Depend on ref and savedInitValue
const handleInput = useCallback(
(event: React.FormEvent<HTMLTextAreaElement>) => {
// Call adjustTextareaHeight on every input - it will decide whether to act
adjustTextareaHeight(event.currentTarget);
},
[]
);
return {
// Method to get the current value directly from the textarea
value: () => {
return textareaRef.current?.value ?? '';
},
// Method to programmatically set the value and trigger height adjustment
setValue: (value: string) => {
const textarea = textareaRef.current;
if (textarea) {
textarea.value = value;
// Call adjustTextareaHeight - it will check screen size internally
setTimeout(() => adjustTextareaHeight(textarea), 0);
}
},
focus: () => {
if (textareaRef.current) {
textareaRef.current.focus();
}
},
ref: textareaRef,
onInput: handleInput,
};
}

View file

@@ -0,0 +1,78 @@
@use 'sass:meta';
@use 'tailwindcss';
@plugin 'daisyui' {
themes: all;
}
html {
scrollbar-gutter: auto;
}
.markdown {
h1,
h2,
h3,
h4,
h5,
h6,
ul,
ol,
li {
all: revert;
}
pre {
@apply whitespace-pre-wrap rounded-lg p-2;
border: 1px solid currentColor;
}
p {
@apply mb-2;
}
/* TODO: fix markdown table */
}
.show-on-hover {
@apply md:opacity-0 md:group-hover:opacity-100;
}
.btn-mini {
@apply cursor-pointer hover:shadow-md;
}
.chat-screen {
max-width: 900px;
}
.chat-bubble-base-300 {
--tw-bg-opacity: 1;
--tw-text-opacity: 1;
@apply bg-base-300 text-base-content;
}
/* Highlight.js */
[data-color-scheme='light'] {
@include meta.load-css('highlight.js/styles/stackoverflow-light');
.dark-color {
@apply bg-base-content text-base-100;
}
}
[data-color-scheme='dark'] {
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
}
[data-color-scheme='auto'] {
@media (prefers-color-scheme: light) {
@include meta.load-css('highlight.js/styles/stackoverflow-light');
.dark-color {
@apply bg-base-content text-base-100;
}
}
@media (prefers-color-scheme: dark) {
@include meta.load-css('highlight.js/styles/stackoverflow-dark');
}
}
.hljs {
background: transparent !important;
padding: 0.5em !important;
}
.katex-display {
margin: 0 0 !important;
}

View file

@@ -0,0 +1,10 @@
import { StrictMode } from 'react';
import { createRoot } from 'react-dom/client';
import './index.scss';
import App from './App.tsx';
createRoot(document.getElementById('root')!).render(
<StrictMode>
<App />
</StrictMode>
);

View file

@@ -0,0 +1,387 @@
import React, { createContext, useContext, useEffect, useState } from 'react';
import {
APIMessage,
CanvasData,
Conversation,
Message,
PendingMessage,
ViewingChat,
} from './types';
import StorageUtils from './storage';
import {
filterThoughtFromMsgs,
normalizeMsgsForAPI,
getSSEStreamAsync,
} from './misc';
import { BASE_URL, CONFIG_DEFAULT, isDev } from '../Config';
import { matchPath, useLocation, useNavigate } from 'react-router';
interface AppContextValue {
// conversations and messages
viewingChat: ViewingChat | null;
pendingMessages: Record<Conversation['id'], PendingMessage>;
isGenerating: (convId: string) => boolean;
sendMessage: (
convId: string | null,
leafNodeId: Message['id'] | null,
content: string,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => Promise<boolean>;
stopGenerating: (convId: string) => void;
replaceMessageAndGenerate: (
convId: string,
parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => Promise<void>;
// canvas
canvasData: CanvasData | null;
setCanvasData: (data: CanvasData | null) => void;
// config
config: typeof CONFIG_DEFAULT;
saveConfig: (config: typeof CONFIG_DEFAULT) => void;
showSettings: boolean;
setShowSettings: (show: boolean) => void;
}
// this callback is used for scrolling to the bottom of the chat and switching to the last node
export type CallbackGeneratedChunk = (currLeafNodeId?: Message['id']) => void;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const AppContext = createContext<AppContextValue>({} as any);
const getViewingChat = async (convId: string): Promise<ViewingChat | null> => {
const conv = await StorageUtils.getOneConversation(convId);
if (!conv) return null;
return {
conv: conv,
// all messages from all branches, not filtered by last node
messages: await StorageUtils.getMessages(convId),
};
};
export const AppContextProvider = ({
children,
}: {
children: React.ReactElement;
}) => {
const { pathname } = useLocation();
const navigate = useNavigate();
const params = matchPath('/chat/:convId', pathname);
const convId = params?.params?.convId;
const [viewingChat, setViewingChat] = useState<ViewingChat | null>(null);
const [pendingMessages, setPendingMessages] = useState<
Record<Conversation['id'], PendingMessage>
>({});
const [aborts, setAborts] = useState<
Record<Conversation['id'], AbortController>
>({});
const [config, setConfig] = useState(StorageUtils.getConfig());
const [canvasData, setCanvasData] = useState<CanvasData | null>(null);
const [showSettings, setShowSettings] = useState(false);
// handle change when the convId from URL is changed
useEffect(() => {
// also reset the canvas data
setCanvasData(null);
const handleConversationChange = async (changedConvId: string) => {
if (changedConvId !== convId) return;
setViewingChat(await getViewingChat(changedConvId));
};
StorageUtils.onConversationChanged(handleConversationChange);
getViewingChat(convId ?? '').then(setViewingChat);
return () => {
StorageUtils.offConversationChanged(handleConversationChange);
};
}, [convId]);
const setPending = (convId: string, pendingMsg: PendingMessage | null) => {
// if pendingMsg is null, remove the key from the object
if (!pendingMsg) {
setPendingMessages((prev) => {
const newState = { ...prev };
delete newState[convId];
return newState;
});
} else {
setPendingMessages((prev) => ({ ...prev, [convId]: pendingMsg }));
}
};
const setAbort = (convId: string, controller: AbortController | null) => {
if (!controller) {
setAborts((prev) => {
const newState = { ...prev };
delete newState[convId];
return newState;
});
} else {
setAborts((prev) => ({ ...prev, [convId]: controller }));
}
};
////////////////////////////////////////////////////////////////////////
// public functions
const isGenerating = (convId: string) => !!pendingMessages[convId];
const generateMessage = async (
convId: string,
leafNodeId: Message['id'],
onChunk: CallbackGeneratedChunk
) => {
if (isGenerating(convId)) return;
const config = StorageUtils.getConfig();
const currConversation = await StorageUtils.getOneConversation(convId);
if (!currConversation) {
throw new Error('Current conversation is not found');
}
const currMessages = StorageUtils.filterByLeafNodeId(
await StorageUtils.getMessages(convId),
leafNodeId,
false
);
const abortController = new AbortController();
setAbort(convId, abortController);
if (!currMessages) {
throw new Error('Current messages are not found');
}
const pendingId = Date.now() + 1;
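    // message ids are timestamps; the +1 presumably avoids colliding with the user message appended just before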
let pendingMsg: PendingMessage = {
id: pendingId,
convId,
type: 'text',
timestamp: pendingId,
role: 'assistant',
content: null,
parent: leafNodeId,
children: [],
};
setPending(convId, pendingMsg);
try {
// prepare messages for API
let messages: APIMessage[] = [
...(config.systemMessage.length === 0
? []
: [{ role: 'system', content: config.systemMessage } as APIMessage]),
...normalizeMsgsForAPI(currMessages),
];
if (config.excludeThoughtOnReq) {
messages = filterThoughtFromMsgs(messages);
}
if (isDev) console.log({ messages });
// prepare params
const params = {
messages,
stream: true,
cache_prompt: true,
samplers: config.samplers,
temperature: config.temperature,
dynatemp_range: config.dynatemp_range,
dynatemp_exponent: config.dynatemp_exponent,
top_k: config.top_k,
top_p: config.top_p,
min_p: config.min_p,
typical_p: config.typical_p,
xtc_probability: config.xtc_probability,
xtc_threshold: config.xtc_threshold,
repeat_last_n: config.repeat_last_n,
repeat_penalty: config.repeat_penalty,
presence_penalty: config.presence_penalty,
frequency_penalty: config.frequency_penalty,
dry_multiplier: config.dry_multiplier,
dry_base: config.dry_base,
dry_allowed_length: config.dry_allowed_length,
dry_penalty_last_n: config.dry_penalty_last_n,
max_tokens: config.max_tokens,
timings_per_token: !!config.showTokensPerSecond,
...(config.custom.length ? JSON.parse(config.custom) : {}),
};
// send request
const fetchResponse = await fetch(`${BASE_URL}/v1/chat/completions`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
...(config.apiKey
? { Authorization: `Bearer ${config.apiKey}` }
: {}),
},
body: JSON.stringify(params),
signal: abortController.signal,
});
if (fetchResponse.status !== 200) {
const body = await fetchResponse.json();
throw new Error(body?.error?.message || 'Unknown error');
}
const chunks = getSSEStreamAsync(fetchResponse);
for await (const chunk of chunks) {
// const stop = chunk.stop;
if (chunk.error) {
throw new Error(chunk.error?.message || 'Unknown error');
}
const addedContent = chunk.choices[0].delta.content;
const lastContent = pendingMsg.content || '';
if (addedContent) {
pendingMsg = {
...pendingMsg,
content: lastContent + addedContent,
};
}
const timings = chunk.timings;
if (timings && config.showTokensPerSecond) {
// only extract what's really needed, to save some space
pendingMsg.timings = {
prompt_n: timings.prompt_n,
prompt_ms: timings.prompt_ms,
predicted_n: timings.predicted_n,
predicted_ms: timings.predicted_ms,
};
}
setPending(convId, pendingMsg);
onChunk(); // don't need to switch node for pending message
}
} catch (err) {
setPending(convId, null);
if ((err as Error).name === 'AbortError') {
// user stopped the generation via stopGeneration() function
// we can safely ignore this error
} else {
console.error(err);
// eslint-disable-next-line @typescript-eslint/no-explicit-any
alert((err as any)?.message ?? 'Unknown error');
throw err; // rethrow
}
}
if (pendingMsg.content !== null) {
await StorageUtils.appendMsg(pendingMsg as Message, leafNodeId);
}
setPending(convId, null);
onChunk(pendingId); // trigger scroll to bottom and switch to the last node
};
const sendMessage = async (
convId: string | null,
leafNodeId: Message['id'] | null,
content: string,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
): Promise<boolean> => {
if (isGenerating(convId ?? '') || content.trim().length === 0) return false;
if (convId === null || convId.length === 0 || leafNodeId === null) {
const conv = await StorageUtils.createConversation(
content.substring(0, 256)
);
convId = conv.id;
leafNodeId = conv.currNode;
// if user is creating a new conversation, redirect to the new conversation
navigate(`/chat/${convId}`);
}
const now = Date.now();
const currMsgId = now;
StorageUtils.appendMsg(
{
id: currMsgId,
timestamp: now,
type: 'text',
convId,
role: 'user',
content,
extra,
parent: leafNodeId,
children: [],
},
leafNodeId
);
onChunk(currMsgId);
try {
await generateMessage(convId, currMsgId, onChunk);
return true;
} catch (_) {
// TODO: rollback
}
return false;
};
const stopGenerating = (convId: string) => {
setPending(convId, null);
aborts[convId]?.abort();
};
  // if content is null, we only regenerate the assistant response branching from parentNodeId (no new user message is added)
const replaceMessageAndGenerate = async (
convId: string,
parentNodeId: Message['id'], // the parent node of the message to be replaced
content: string | null,
extra: Message['extra'],
onChunk: CallbackGeneratedChunk
) => {
if (isGenerating(convId)) return;
if (content !== null) {
const now = Date.now();
const currMsgId = now;
StorageUtils.appendMsg(
{
id: currMsgId,
timestamp: now,
type: 'text',
convId,
role: 'user',
content,
extra,
parent: parentNodeId,
children: [],
},
parentNodeId
);
parentNodeId = currMsgId;
}
onChunk(parentNodeId);
await generateMessage(convId, parentNodeId, onChunk);
};
const saveConfig = (config: typeof CONFIG_DEFAULT) => {
StorageUtils.setConfig(config);
setConfig(config);
};
return (
<AppContext.Provider
value={{
isGenerating,
viewingChat,
pendingMessages,
sendMessage,
stopGenerating,
replaceMessageAndGenerate,
canvasData,
setCanvasData,
config,
saveConfig,
showSettings,
setShowSettings,
}}
>
{children}
</AppContext.Provider>
);
};
export const useAppContext = () => useContext(AppContext);

View file

@ -0,0 +1,38 @@
export const XCloseButton: React.ElementType<
React.ClassAttributes<HTMLButtonElement> &
React.HTMLAttributes<HTMLButtonElement>
> = ({ className, ...props }) => (
<button className={`btn btn-square btn-sm ${className ?? ''}`} {...props}>
<svg
xmlns="http://www.w3.org/2000/svg"
className="h-6 w-6"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth="2"
d="M6 18L18 6M6 6l12 12"
/>
</svg>
</button>
);
export const OpenInNewTab = ({
href,
children,
}: {
href: string;
children: string;
}) => (
<a
className="underline"
href={href}
target="_blank"
rel="noopener noreferrer"
>
{children}
</a>
);

View file

@ -0,0 +1,60 @@
import { useEffect, useState } from 'react';
import { MessageExtraContext } from './types';
import { ChatTextareaApi } from '../components/useChatTextarea.ts';
// Extra context when using llama.cpp WebUI from llama-vscode, inside an iframe
// Ref: https://github.com/ggml-org/llama.cpp/pull/11940
interface SetTextEvData {
text: string;
context: string;
}
/**
* To test it:
* window.postMessage({ command: 'setText', text: 'Spot the syntax error', context: 'def test()\n return 123' }, '*');
*/
export const useVSCodeContext = (textarea: ChatTextareaApi) => {
const [extraContext, setExtraContext] = useState<MessageExtraContext | null>(
null
);
  // Accept a setText message from the parent window and populate the textarea and extraContext
useEffect(() => {
const handleMessage = (event: MessageEvent) => {
if (event.data?.command === 'setText') {
const data: SetTextEvData = event.data;
textarea.setValue(data?.text);
if (data?.context && data.context.length > 0) {
setExtraContext({
type: 'context',
content: data.context,
});
}
textarea.focus();
}
};
window.addEventListener('message', handleMessage);
return () => window.removeEventListener('message', handleMessage);
}, [textarea]);
// Add a keydown listener that sends the "escapePressed" message to the parent window
useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => {
if (event.key === 'Escape') {
window.parent.postMessage({ command: 'escapePressed' }, '*');
}
};
window.addEventListener('keydown', handleKeyDown);
return () => window.removeEventListener('keydown', handleKeyDown);
}, []);
return {
extraContext,
// call once the user message is sent, to clear the extra context
clearExtraContext: () => setExtraContext(null),
};
};
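
// A minimal sketch of the parent-window side of this protocol (illustrative only;
// the real llama-vscode extension may wire this up differently, and the element id
// used here is hypothetical):
//
//   const iframe = document.getElementById('llama-webui') as HTMLIFrameElement;
//   // push the selected text and surrounding code into the WebUI
//   iframe.contentWindow?.postMessage(
//     { command: 'setText', text: 'Explain this snippet', context: 'const x = 1;' },
//     '*'
//   );
//   // react to Escape being pressed inside the WebUI iframe
//   window.addEventListener('message', (event) => {
//     if (event.data?.command === 'escapePressed') {
//       // e.g. hide the chat panel
//     }
//   });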

View file

@ -0,0 +1,128 @@
// @ts-expect-error this package does not have typing
import TextLineStream from 'textlinestream';
import { APIMessage, Message } from './types';
// ponyfill for missing ReadableStream asyncIterator on Safari
import { asyncIterator } from '@sec-ant/readable-stream/ponyfill/asyncIterator';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const isString = (x: any) => !!x.toLowerCase; // duck-typed: anything with a toLowerCase method counts as a string
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const isBoolean = (x: any) => x === true || x === false;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export const isNumeric = (n: any) => !isString(n) && !isNaN(n) && !isBoolean(n);
export const escapeAttr = (str: string) =>
str.replace(/>/g, '&gt;').replace(/"/g, '&quot;');
// wrapper for SSE
export async function* getSSEStreamAsync(fetchResponse: Response) {
if (!fetchResponse.body) throw new Error('Response body is empty');
const lines: ReadableStream<string> = fetchResponse.body
.pipeThrough(new TextDecoderStream())
.pipeThrough(new TextLineStream());
// @ts-expect-error asyncIterator complains about type, but it should work
for await (const line of asyncIterator(lines)) {
//if (isDev) console.log({ line });
if (line.startsWith('data:') && !line.endsWith('[DONE]')) {
const data = JSON.parse(line.slice(5));
yield data;
} else if (line.startsWith('error:')) {
const data = JSON.parse(line.slice(6));
throw new Error(data.message || 'Unknown error');
}
}
}
// copy text to clipboard
export const copyStr = (textToCopy: string) => {
// Navigator clipboard api needs a secure context (https)
if (navigator.clipboard && window.isSecureContext) {
navigator.clipboard.writeText(textToCopy);
} else {
// Use the 'out of viewport hidden text area' trick
const textArea = document.createElement('textarea');
textArea.value = textToCopy;
// Move textarea out of the viewport so it's not visible
textArea.style.position = 'absolute';
textArea.style.left = '-999999px';
document.body.prepend(textArea);
textArea.select();
document.execCommand('copy');
}
};
/**
* filter out redundant fields upon sending to API
* also format extra into text
*/
export function normalizeMsgsForAPI(messages: Readonly<Message[]>) {
return messages.map((msg) => {
let newContent = '';
for (const extra of msg.extra ?? []) {
if (extra.type === 'context') {
newContent += `${extra.content}\n\n`;
}
}
newContent += msg.content;
return {
role: msg.role,
content: newContent,
};
}) as APIMessage[];
}
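
// Illustrative example of the transformation above (per the code, only 'context'
// extras are inlined into the text; other extra types are dropped from the API payload):
//
//   { role: 'user', content: 'What does this do?',
//     extra: [{ type: 'context', content: 'def foo(): ...' }], ... }
//
// becomes:
//
//   { role: 'user', content: 'def foo(): ...\n\nWhat does this do?' }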
/**
* recommended for DeepSeek-R1: filters out content between <think> and </think> tags
*/
export function filterThoughtFromMsgs(messages: APIMessage[]) {
return messages.map((msg) => {
return {
role: msg.role,
content:
msg.role === 'assistant'
? msg.content.split('</think>').at(-1)!.trim()
: msg.content,
} as APIMessage;
});
}
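
// Illustrative example: only the text after the last </think> tag of an assistant
// message is kept; user and system messages pass through unchanged.
//
//   { role: 'assistant', content: '<think>check the docs…</think>\nThe answer is 42.' }
//     -> { role: 'assistant', content: 'The answer is 42.' }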
export function classNames(classes: Record<string, boolean>): string {
return Object.entries(classes)
.filter(([_, value]) => value)
.map(([key, _]) => key)
.join(' ');
}
export const delay = (ms: number) =>
new Promise((resolve) => setTimeout(resolve, ms));
export const throttle = <T extends unknown[]>(
callback: (...args: T) => void,
delay: number
) => {
let isWaiting = false;
return (...args: T) => {
if (isWaiting) {
return;
}
callback(...args);
isWaiting = true;
setTimeout(() => {
isWaiting = false;
}, delay);
};
};
export const cleanCurrentUrl = (removeQueryParams: string[]) => {
const url = new URL(window.location.href);
removeQueryParams.forEach((param) => {
url.searchParams.delete(param);
});
window.history.replaceState({}, '', url.toString());
};

View file

@ -0,0 +1,284 @@
// conversations are stored in IndexedDB (via Dexie); any legacy data found in localStorage,
// format: { [convId]: { id: string, lastModified: number, messages: [...] } },
// is migrated to IndexedDB on first load (see migrationLStoIDB below)
import { CONFIG_DEFAULT } from '../Config';
import { Conversation, Message, TimingReport } from './types';
import Dexie, { Table } from 'dexie';
const event = new EventTarget();
type CallbackConversationChanged = (convId: string) => void;
let onConversationChangedHandlers: [
CallbackConversationChanged,
EventListener,
][] = [];
const dispatchConversationChange = (convId: string) => {
event.dispatchEvent(
new CustomEvent('conversationChange', { detail: { convId } })
);
};
const db = new Dexie('LlamacppWebui') as Dexie & {
conversations: Table<Conversation>;
messages: Table<Message>;
};
// https://dexie.org/docs/Version/Version.stores()
db.version(1).stores({
  // Unlike SQL, you don't need to specify all properties, only the ones you wish to index.
conversations: '&id, lastModified',
messages: '&id, convId, [convId+id], timestamp',
});
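// For illustration: with the schema above, db.messages.where({ convId, id }) can be served
// by the compound [convId+id] index, while db.messages.where({ convId }) uses the plain
// convId index (see appendMsg() and getMessages() below).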
// convId is a string prefixed with 'conv-'
const StorageUtils = {
/**
* manage conversations
*/
async getAllConversations(): Promise<Conversation[]> {
await migrationLStoIDB().catch(console.error); // noop if already migrated
return (await db.conversations.toArray()).sort(
(a, b) => b.lastModified - a.lastModified
);
},
/**
* can return null if convId does not exist
*/
async getOneConversation(convId: string): Promise<Conversation | null> {
return (await db.conversations.where('id').equals(convId).first()) ?? null;
},
/**
* get all message nodes in a conversation
*/
async getMessages(convId: string): Promise<Message[]> {
return await db.messages.where({ convId }).toArray();
},
/**
* use in conjunction with getMessages to filter messages by leafNodeId
* includeRoot: whether to include the root node in the result
* if node with leafNodeId does not exist, return the path with the latest timestamp
*/
filterByLeafNodeId(
msgs: Readonly<Message[]>,
leafNodeId: Message['id'],
includeRoot: boolean
): Readonly<Message[]> {
const res: Message[] = [];
const nodeMap = new Map<Message['id'], Message>();
for (const msg of msgs) {
nodeMap.set(msg.id, msg);
}
let startNode: Message | undefined = nodeMap.get(leafNodeId);
if (!startNode) {
// if not found, we return the path with the latest timestamp
let latestTime = -1;
for (const msg of msgs) {
if (msg.timestamp > latestTime) {
startNode = msg;
latestTime = msg.timestamp;
}
}
}
// traverse the path from leafNodeId to root
// startNode can never be undefined here
let currNode: Message | undefined = startNode;
while (currNode) {
if (currNode.type !== 'root' || (currNode.type === 'root' && includeRoot))
res.push(currNode);
currNode = nodeMap.get(currNode.parent ?? -1);
}
res.sort((a, b) => a.timestamp - b.timestamp);
return res;
},
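  // Illustrative sketch, using the branching example documented in types.ts
  // (ids shown by name; in the app they are numeric timestamps):
  //
  //   root ── message 1 ── message 2 ── message 3
  //       └── message 4 ── message 5
  //
  //   filterByLeafNodeId(allMsgs, idOfMessage3, false)
  //     -> [message 1, message 2, message 3]  (sorted by timestamp, root omitted)
  //
  // If leafNodeId matches no node, the path ending at the node with the latest
  // timestamp is returned instead.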
/**
* create a new conversation with a default root node
*/
async createConversation(name: string): Promise<Conversation> {
const now = Date.now();
const msgId = now;
const conv: Conversation = {
id: `conv-${now}`,
lastModified: now,
currNode: msgId,
name,
};
await db.conversations.add(conv);
// create a root node
await db.messages.add({
id: msgId,
convId: conv.id,
type: 'root',
timestamp: now,
role: 'system',
content: '',
parent: -1,
children: [],
});
return conv;
},
/**
* if convId does not exist, throw an error
*/
async appendMsg(
msg: Exclude<Message, 'parent' | 'children'>,
parentNodeId: Message['id']
): Promise<void> {
if (msg.content === null) return;
const { convId } = msg;
await db.transaction('rw', db.conversations, db.messages, async () => {
const conv = await StorageUtils.getOneConversation(convId);
const parentMsg = await db.messages
.where({ convId, id: parentNodeId })
.first();
// update the currNode of conversation
if (!conv) {
throw new Error(`Conversation ${convId} does not exist`);
}
if (!parentMsg) {
throw new Error(
`Parent message ID ${parentNodeId} does not exist in conversation ${convId}`
);
}
await db.conversations.update(convId, {
lastModified: Date.now(),
currNode: msg.id,
});
// update parent
await db.messages.update(parentNodeId, {
children: [...parentMsg.children, msg.id],
});
// create message
await db.messages.add({
...msg,
parent: parentNodeId,
children: [],
});
});
dispatchConversationChange(convId);
},
/**
* remove conversation by id
*/
async remove(convId: string): Promise<void> {
await db.transaction('rw', db.conversations, db.messages, async () => {
await db.conversations.delete(convId);
await db.messages.where({ convId }).delete();
});
dispatchConversationChange(convId);
},
// event listeners
onConversationChanged(callback: CallbackConversationChanged) {
const fn = (e: Event) => callback((e as CustomEvent).detail.convId);
onConversationChangedHandlers.push([callback, fn]);
event.addEventListener('conversationChange', fn);
},
  offConversationChanged(callback: CallbackConversationChanged) {
    const entry = onConversationChangedHandlers.find(([cb, _]) => cb === callback);
    if (entry) {
      event.removeEventListener('conversationChange', entry[1]);
    }
    // only drop the handler being unregistered, keep the others registered
    onConversationChangedHandlers = onConversationChangedHandlers.filter(
      ([cb, _]) => cb !== callback
    );
  },
// manage config
getConfig(): typeof CONFIG_DEFAULT {
const savedVal = JSON.parse(localStorage.getItem('config') || '{}');
    // to prevent breaking changes in the future, we always provide default values for missing keys
return {
...CONFIG_DEFAULT,
...savedVal,
};
},
setConfig(config: typeof CONFIG_DEFAULT) {
localStorage.setItem('config', JSON.stringify(config));
},
getTheme(): string {
return localStorage.getItem('theme') || 'auto';
},
setTheme(theme: string) {
if (theme === 'auto') {
localStorage.removeItem('theme');
} else {
localStorage.setItem('theme', theme);
}
},
};
export default StorageUtils;
// Migration from localStorage to IndexedDB
// these are old types, LS prefix stands for LocalStorage
interface LSConversation {
id: string; // format: `conv-{timestamp}`
lastModified: number; // timestamp from Date.now()
messages: LSMessage[];
}
interface LSMessage {
id: number;
role: 'user' | 'assistant' | 'system';
content: string;
timings?: TimingReport;
}
async function migrationLStoIDB() {
if (localStorage.getItem('migratedToIDB')) return;
const res: LSConversation[] = [];
for (const key in localStorage) {
if (key.startsWith('conv-')) {
res.push(JSON.parse(localStorage.getItem(key) ?? '{}'));
}
}
if (res.length === 0) return;
await db.transaction('rw', db.conversations, db.messages, async () => {
let migratedCount = 0;
for (const conv of res) {
const { id: convId, lastModified, messages } = conv;
const firstMsg = messages[0];
const lastMsg = messages.at(-1);
if (messages.length < 2 || !firstMsg || !lastMsg) {
console.log(
`Skipping conversation ${convId} with ${messages.length} messages`
);
continue;
}
const name = firstMsg.content ?? '(no messages)';
await db.conversations.add({
id: convId,
lastModified,
currNode: lastMsg.id,
name,
});
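      // synthesize a root node whose id (also used as its timestamp) sorts before the first migrated message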
const rootId = messages[0].id - 2;
await db.messages.add({
id: rootId,
convId: convId,
type: 'root',
timestamp: rootId,
role: 'system',
content: '',
parent: -1,
children: [firstMsg.id],
});
for (let i = 0; i < messages.length; i++) {
const msg = messages[i];
await db.messages.add({
...msg,
type: 'text',
convId: convId,
timestamp: msg.id,
parent: i === 0 ? rootId : messages[i - 1].id,
children: i === messages.length - 1 ? [] : [messages[i + 1].id],
});
}
migratedCount++;
console.log(
`Migrated conversation ${convId} with ${messages.length} messages`
);
}
console.log(
`Migrated ${migratedCount} conversations from localStorage to IndexedDB`
);
localStorage.setItem('migratedToIDB', '1');
});
}

View file

@ -0,0 +1,91 @@
export interface TimingReport {
prompt_n: number;
prompt_ms: number;
predicted_n: number;
predicted_ms: number;
}
/**
* What is conversation "branching"? It is a feature that allows the user to edit an old message in the history, while still keeping the conversation flow.
* Inspired by ChatGPT / Claude / Hugging Chat: when you edit a message, a new branch of the conversation is created while the old message remains visible.
*
* We use the same node-based structure as other chat UIs, where each message has a parent and children. A "root" message is the first message in a conversation and is never displayed in the UI.
*
* root
*   ├── message 1
*   │      └── message 2
*   │             └── message 3
*   └── message 4
*          └── message 5
*
* In the above example, assuming that the user wants to edit message 2, a new branch will be created:
*
*   ├── message 2
*   │      └── message 3
*   └── message 6
*
* Messages 2 and 6 are siblings, and message 6 is the new branch.
*
* We only need to know the last node (aka leaf) to get the current branch. In the above example, message 5 is the leaf of the branch containing messages 4 and 5.
*
* For the implementation:
* - StorageUtils.getMessages() returns list of all nodes
* - StorageUtils.filterByLeafNodeId() filters the list of nodes from a given leaf node
*/
// Note: the terms "message" and "node" are used interchangeably in this context
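// Illustrative sketch of the node data after editing message 2 in the example above
// (ids are arbitrary; in the app they are Date.now() timestamps; the message 4/5
// branch is omitted for brevity):
//
//   { id: 100, type: 'root', parent: -1,  children: [101, 104] }
//   { id: 101, type: 'text', parent: 100, children: [102, 106] } // message 1
//   { id: 102, type: 'text', parent: 101, children: [103] }      // message 2
//   { id: 103, type: 'text', parent: 102, children: [] }         // message 3
//   { id: 106, type: 'text', parent: 101, children: [] }         // message 6 (new branch)
//
// Viewing leaf 103 shows messages 1 → 2 → 3; viewing leaf 106 shows 1 → 6.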
export interface Message {
id: number;
convId: string;
type: 'text' | 'root';
timestamp: number; // timestamp from Date.now()
role: 'user' | 'assistant' | 'system';
content: string;
timings?: TimingReport;
extra?: MessageExtra[];
// node based system for branching
parent: Message['id'];
children: Message['id'][];
}
type MessageExtra = MessageExtraTextFile | MessageExtraContext; // TODO: will add more in the future
export interface MessageExtraTextFile {
type: 'textFile';
name: string;
content: string;
}
export interface MessageExtraContext {
type: 'context';
content: string;
}
export type APIMessage = Pick<Message, 'role' | 'content'>;
export interface Conversation {
id: string; // format: `conv-{timestamp}`
lastModified: number; // timestamp from Date.now()
currNode: Message['id']; // the current message node being viewed
name: string;
}
export interface ViewingChat {
conv: Readonly<Conversation>;
messages: Readonly<Message[]>;
}
export type PendingMessage = Omit<Message, 'content'> & {
content: string | null;
};
export enum CanvasType {
PY_INTERPRETER,
}
export interface CanvasPyInterpreter {
type: CanvasType.PY_INTERPRETER;
content: string;
}
export type CanvasData = CanvasPyInterpreter;

1
tools/server/webui/src/vite-env.d.ts vendored Normal file
View file

@ -0,0 +1 @@
/// <reference types="vite/client" />

View file

@ -0,0 +1,16 @@
/** @type {import('tailwindcss').Config} */
export default {
content: [
"./index.html",
"./src/**/*.{js,ts,jsx,tsx}",
],
theme: {
extend: {},
},
plugins: [
require('daisyui'),
],
daisyui: {
themes: ['light', 'dark', 'cupcake', 'bumblebee', 'emerald', 'corporate', 'synthwave', 'retro', 'cyberpunk', 'valentine', 'halloween', 'garden', 'forest', 'aqua', 'lofi', 'pastel', 'fantasy', 'wireframe', 'black', 'luxury', 'dracula', 'cmyk', 'autumn', 'business', 'acid', 'lemonade', 'night', 'coffee', 'winter', 'dim', 'nord', 'sunset'],
}
}

View file

@ -0,0 +1,26 @@
{
"compilerOptions": {
"tsBuildInfoFile": "./node_modules/.tmp/tsconfig.app.tsbuildinfo",
"target": "ES2021",
"useDefineForClassFields": true,
"lib": ["ES2021", "DOM", "DOM.Iterable"],
"module": "ESNext",
"skipLibCheck": true,
/* Bundler mode */
"moduleResolution": "bundler",
"allowImportingTsExtensions": true,
"isolatedModules": true,
"moduleDetection": "force",
"noEmit": true,
"jsx": "react-jsx",
/* Linting */
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noFallthroughCasesInSwitch": true,
"noUncheckedSideEffectImports": true
},
"include": ["src"]
}

View file

@ -0,0 +1,7 @@
{
"files": [],
"references": [
{ "path": "./tsconfig.app.json" },
{ "path": "./tsconfig.node.json" }
]
}

Some files were not shown because too many files have changed in this diff.