-
Notifications
You must be signed in to change notification settings - Fork 42
Expand file tree
/
Copy pathtq
More file actions
executable file
·552 lines (478 loc) · 21.7 KB
/
tq
File metadata and controls
executable file
·552 lines (478 loc) · 21.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
#!/usr/bin/env python3
"""
tq — TurboQuant CLI
Unified command-line interface for KV cache compression.
Designed for both humans and AI agents (JSON output by default).
Usage:
tq quantize <input> [--type TYPE] [--output FILE]
tq bench [--seq-len N] [--head-dim N] [--json]
tq info [--type TYPE]
tq demo [--question TEXT]
tq +compare # A/B test helper
tq +memory <model> <context> # Memory savings calculator
Google CLI design principles applied:
- JSON-first output (--json flag, default for scripts)
- Structured exit codes
- Help-driven discovery
- + prefix for high-level helpers
"""
import sys
import os
import json
import argparse
import time
import struct
# numpy is optional: the CLI must still work for pull/list/run/serve without
# it. In this file only the bench command touches `np`.
try:
    import numpy as np
except ImportError:
    np = None

# Prefer the in-repo Python bindings (../bindings/python relative to this
# script) over any installed copy by putting them first on the import path.
sys.path.insert(
    0,
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "../bindings/python"),
)
# ═══════════════════════════════════════════════════════════
# Colors (disabled when piped)
# ═══════════════════════════════════════════════════════════
# Styling only makes sense on an interactive terminal; piped output stays plain.
IS_TTY = sys.stdout.isatty()


class C:
    """ANSI style codes and bar glyphs used by the human-readable output.

    When stdout is not a TTY every colour attribute is the empty string and
    the bar glyphs fall back to plain ASCII ("#" filled, "-" empty).
    """

    if not IS_TTY:
        BOLD = DIM = NC = ""
        CYAN = GREEN = YELLOW = RED = MAGENTA = BLUE = ""
        BAR, BAR_E = "#", "-"
    else:
        BOLD, DIM, NC = "\033[1m", "\033[2m", "\033[0m"
        CYAN, GREEN, YELLOW = "\033[36m", "\033[32m", "\033[33m"
        RED, MAGENTA, BLUE = "\033[31m", "\033[35m", "\033[34m"
        BAR, BAR_E = "█", "░"
def bar(val, mx, w=25):
    """Render a ``w``-cell horizontal bar showing ``val`` as a fraction of ``mx``.

    The filled length is clamped to ``[0, w]``, so values above ``mx`` or
    below zero never yield a bar wider than ``w`` cells. A non-positive
    ``mx`` renders an empty bar. The result is wrapped in colour codes from
    ``C`` (empty strings when stdout is not a terminal).
    """
    filled = int(val / mx * w) if mx > 0 else 0
    # Clamp both ends: a negative fill would make the empty part exceed w.
    filled = max(0, min(filled, w))
    return f"{C.GREEN}{C.BAR * filled}{C.DIM}{C.BAR_E * (w - filled)}{C.NC}"
def sz(b):
    """Format a byte count with decimal (SI) units: B, KB, MB or GB.

    Accepts ints or floats. Values under 1 KB are shown as whole bytes, so a
    float such as ``307.2`` renders as ``"307 B"`` instead of leaking its
    full repr. Unit thresholds compare the magnitude, so negative values
    (e.g. deltas) scale the same way as positive ones.
    """
    magnitude = abs(b)
    if magnitude >= 1e9:
        return f"{b / 1e9:.2f} GB"
    if magnitude >= 1e6:
        return f"{b / 1e6:.1f} MB"
    if magnitude >= 1e3:
        return f"{b / 1e3:.1f} KB"
    return f"{b:.0f} B"
# ═══════════════════════════════════════════════════════════
# EXIT CODES (structured, Google CLI pattern)
# ═══════════════════════════════════════════════════════════
# Process exit statuses returned by the cmd_* handlers and passed to sys.exit().
(
    EXIT_OK,           # 0: command succeeded
    EXIT_USAGE,        # 1: no command given, or unknown model name
    EXIT_LIB_MISSING,  # 2: bindings, native library or binary not available
    EXIT_MODEL_ERROR,  # 3: currently not returned anywhere in this file
    EXIT_IO_ERROR,     # 4: model download failed
) = range(5)
# ═══════════════════════════════════════════════════════════
# Ollama-style model registry (short alias → Python registry key)
# ═══════════════════════════════════════════════════════════
# User-friendly short names. Maps to quantcpp.* registry keys.
# Lower-case short alias -> canonical quantcpp registry key. Each canonical
# key is listed once together with all of its aliases; the comprehension
# preserves this order in the resulting dict.
MODEL_ALIASES = {
    alias: canonical
    for canonical, aliases in (
        ("SmolLM2-135M", ("smollm2", "smollm2:135m")),
        ("Qwen3.5-0.8B", ("qwen3.5", "qwen3.5:0.8b")),
        ("Llama-3.2-1B", ("llama3.2", "llama3.2:1b")),
    )
    for alias in aliases
}
def resolve_model_name(name):
    """Translate user input into the name the model commands look up.

    Resolution order:
      1. ``None`` is passed through unchanged.
      2. An existing file whose name ends in ``.gguf`` is returned as-is.
      3. A short alias (matched case-insensitively against MODEL_ALIASES,
         e.g. ``llama3.2:1b``) becomes its canonical registry key.
      4. Anything else is returned untouched so callers can try it as a
         canonical key (e.g. ``Llama-3.2-1B``).
    """
    if name is None:
        return None
    if os.path.exists(name) and name.endswith(".gguf"):
        return name
    return MODEL_ALIASES.get(name.lower(), name)
def _load_quantcpp():
    """Return the imported ``quantcpp`` bindings module.

    If the import fails, print the error plus install hints to stderr and
    terminate via ``sys.exit(EXIT_LIB_MISSING)``; in that case this function
    never returns.
    """
    try:
        import quantcpp
    except ImportError as exc:
        hints = (
            f"{C.RED}error:{C.NC} quantcpp bindings not importable: {exc}",
            f" install: {C.CYAN}pip install quantcpp{C.NC}",
            f" or dev: {C.CYAN}cd bindings/python && pip install -e .{C.NC}",
        )
        for line in hints:
            print(line, file=sys.stderr)
        sys.exit(EXIT_LIB_MISSING)
    return quantcpp
def _find_quant_binary():
"""Locate the ./build/quant binary relative to this script."""
here = os.path.dirname(os.path.abspath(__file__))
project = os.path.dirname(here)
candidates = [
os.path.join(project, "build", "quant"),
os.path.join(project, "build_metal", "quant"),
"quant", # in PATH
]
for c in candidates:
if os.path.isfile(c) and os.access(c, os.X_OK):
return c
# shutil.which fallback
import shutil
found = shutil.which("quant")
if found:
return found
return None
def _find_quant_server_binary():
here = os.path.dirname(os.path.abspath(__file__))
project = os.path.dirname(here)
candidates = [
os.path.join(project, "build", "quant-server"),
os.path.join(project, "build_metal", "quant-server"),
"quant-server",
]
for c in candidates:
if os.path.isfile(c) and os.access(c, os.X_OK):
return c
import shutil
return shutil.which("quant-server")
# ═══════════════════════════════════════════════════════════
# Ollama-style commands: pull / list / run / serve
# ═══════════════════════════════════════════════════════════
def cmd_pull(args):
    """Fetch a model into the local quantcpp cache.

    ``args.model`` may be a short alias, a canonical registry key, or a path
    to an existing ``.gguf`` file (which is only acknowledged, not copied).

    Returns:
        EXIT_OK if the model is local or was downloaded; EXIT_USAGE for a
        name the registry does not know; EXIT_IO_ERROR if the download, the
        size lookup or the success message raises.
    """
    quantcpp = _load_quantcpp()
    target = resolve_model_name(args.model)

    if os.path.exists(target) and target.endswith(".gguf"):
        print(f"{C.GREEN}already local:{C.NC} {target}")
        return EXIT_OK

    if target not in quantcpp._MODEL_REGISTRY:
        known = ", ".join(sorted(quantcpp._MODEL_REGISTRY.keys()))
        short_names = ", ".join(sorted(MODEL_ALIASES.keys()))
        for line in (
            f"{C.RED}unknown model:{C.NC} {args.model!r}",
            f" registry: {known}",
            f" aliases: {short_names}",
        ):
            print(line, file=sys.stderr)
        return EXIT_USAGE

    print(f"{C.CYAN}pulling{C.NC} {target}...")
    try:
        path = quantcpp.download(target)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f"{C.GREEN}✓{C.NC} {target} → {path} ({size_mb:.0f} MB)")
        return EXIT_OK
    except Exception as err:
        print(f"{C.RED}download failed:{C.NC} {err}", file=sys.stderr)
        return EXIT_IO_ERROR
def cmd_list(args):
    """List every registry model together with its cache status.

    A model is "cached" when ``quantcpp._CACHE_DIR / filename`` exists; its
    on-disk size and path are then reported. Otherwise it is "remote" and the
    registry's approximate size is reported (the JSON "path" field then holds
    a "~N MB" string). Emits a JSON array when ``args.json_output`` is set,
    otherwise a table. Always returns EXIT_OK.
    """
    quantcpp = _load_quantcpp()
    cache_dir = quantcpp._CACHE_DIR
    registry = quantcpp._MODEL_REGISTRY

    def sized_alias(canonical):
        # Show the "family:size" alias (the one containing ':'), if any.
        return next((a for a, n in MODEL_ALIASES.items() if n == canonical and ":" in a), "")

    rows = []  # (status, name, alias, size_mb, path or "~N MB")
    for name, (_repo, filename, approx_mb) in sorted(registry.items()):
        local = cache_dir / filename
        is_cached = local.exists()
        status = "cached" if is_cached else "remote"
        size_mb = local.stat().st_size / (1024 * 1024) if is_cached else approx_mb
        location = str(local) if is_cached else f"~{approx_mb} MB"
        rows.append((status, name, sized_alias(name), size_mb, location))

    if args.json_output:
        records = [
            {"status": status, "name": name, "alias": alias, "size_mb": round(size_mb, 1), "path": location}
            for status, name, alias, size_mb, location in rows
        ]
        print(json.dumps(records, indent=2))
        return EXIT_OK

    print(f"\n {C.BOLD}Models{C.NC} cache: {cache_dir}\n")
    print(f" {C.BOLD}{'STATUS':<8} {'NAME':<16} {'ALIAS':<14} {'SIZE':>8}{C.NC}")
    print(f" {'─'*8} {'─'*16} {'─'*14} {'─'*8}")
    for status, name, alias, size_mb, _location in rows:
        color = C.GREEN if status == "cached" else C.DIM
        size_str = f"{size_mb:.0f} MB"
        print(f" {color}{status:<8}{C.NC} {name:<16} {C.DIM}{alias:<14}{C.NC} {size_str:>8}")
    print()
    return EXIT_OK
def cmd_run(args):
    """Chat with a model through the native ``quant`` binary (auto-pull).

    ``args.model`` is resolved with ``resolve_model_name``. A local ``.gguf``
    path is used directly. A registry model is used from the cache, or
    downloaded first. The process is then replaced (``os.execvp``) by
    ``quant <model> --chat [-p PROMPT] -j THREADS -n MAX_TOKENS``.

    Returns (only when the exec does not happen):
        EXIT_IO_ERROR if the download fails, EXIT_USAGE for an unknown model,
        EXIT_LIB_MISSING if the binary is missing or cannot be executed.
    """
    quantcpp = _load_quantcpp()
    registry = quantcpp._MODEL_REGISTRY
    name = resolve_model_name(args.model)

    # Resolve to a local model path, pulling it if needed.
    if os.path.exists(name) and name.endswith(".gguf"):
        model_path = name
    elif name in registry:
        _repo, filename, _approx_mb = registry[name]
        cached = quantcpp._CACHE_DIR / filename
        if cached.exists():
            model_path = str(cached)
        else:
            print(f"{C.CYAN}model not cached — pulling{C.NC} {name}")
            try:
                # str(): the path is joined into the echoed command line below.
                model_path = str(quantcpp.download(name))
            except Exception as e:
                print(f"{C.RED}pull failed:{C.NC} {e}", file=sys.stderr)
                return EXIT_IO_ERROR
    else:
        avail = ", ".join(sorted(registry.keys()))
        print(f"{C.RED}unknown model:{C.NC} {args.model!r}", file=sys.stderr)
        print(f" available: {avail}", file=sys.stderr)
        return EXIT_USAGE

    binary = _find_quant_binary()
    if not binary:
        print(f"{C.RED}quant binary not found:{C.NC} run `cmake --build build` first", file=sys.stderr)
        return EXIT_LIB_MISSING
    if os.path.isfile(binary):
        # execvp() searches $PATH for a slash-less name; pin the file we found.
        binary = os.path.abspath(binary)

    cmd = [binary, model_path, "--chat"]
    if args.prompt:
        cmd += ["-p", args.prompt]
    cmd += ["-j", str(args.threads), "-n", str(args.max_tokens)]
    print(f"{C.DIM}→ {' '.join(cmd)}{C.NC}")

    # exec* replaces the process without flushing Python's stdio buffers;
    # when stdout is piped, everything printed above would otherwise be lost.
    sys.stdout.flush()
    sys.stderr.flush()
    try:
        os.execvp(cmd[0], cmd)
    except OSError as e:
        print(f"{C.RED}failed to start{C.NC} {cmd[0]}: {e}", file=sys.stderr)
        return EXIT_LIB_MISSING
def cmd_serve(args):
    """Serve a model over the OpenAI-compatible HTTP API (auto-pull).

    Model resolution matches ``cmd_run``. The process is then replaced
    (``os.execvp``) by ``quant-server <model> -p PORT -j THREADS``.

    Returns (only when the exec does not happen):
        EXIT_IO_ERROR if the download fails, EXIT_USAGE for an unknown model,
        EXIT_LIB_MISSING if the server binary is missing or cannot be run.
    """
    quantcpp = _load_quantcpp()
    registry = quantcpp._MODEL_REGISTRY
    name = resolve_model_name(args.model)

    if os.path.exists(name) and name.endswith(".gguf"):
        model_path = name
    elif name in registry:
        _repo, filename, _approx_mb = registry[name]
        cached = quantcpp._CACHE_DIR / filename
        if cached.exists():
            model_path = str(cached)
        else:
            print(f"{C.CYAN}model not cached — pulling{C.NC} {name}")
            try:
                # str(): the path is joined into the echoed command line below.
                model_path = str(quantcpp.download(name))
            except Exception as e:
                print(f"{C.RED}pull failed:{C.NC} {e}", file=sys.stderr)
                return EXIT_IO_ERROR
    else:
        # Same hint as `tq run`: tell the user which names are valid.
        avail = ", ".join(sorted(registry.keys()))
        print(f"{C.RED}unknown model:{C.NC} {args.model!r}", file=sys.stderr)
        print(f" available: {avail}", file=sys.stderr)
        return EXIT_USAGE

    binary = _find_quant_server_binary()
    if not binary:
        print(f"{C.RED}quant-server binary not found:{C.NC} build with "
              f"`cmake -B build -DTQ_BUILD_SERVER=ON && cmake --build build`",
              file=sys.stderr)
        return EXIT_LIB_MISSING
    if os.path.isfile(binary):
        # execvp() searches $PATH for a slash-less name; pin the file we found.
        binary = os.path.abspath(binary)

    cmd = [binary, model_path, "-p", str(args.port), "-j", str(args.threads)]
    print(f"{C.GREEN}quant serve{C.NC} {name} on :{args.port}")
    print(f"{C.DIM}→ {' '.join(cmd)}{C.NC}")

    # exec* does not flush Python's stdio buffers; flush so piped output survives.
    sys.stdout.flush()
    sys.stderr.flush()
    try:
        os.execvp(cmd[0], cmd)
    except OSError as e:
        print(f"{C.RED}failed to start{C.NC} {cmd[0]}: {e}", file=sys.stderr)
        return EXIT_LIB_MISSING
# ═══════════════════════════════════════════════════════════
# COMMANDS
# ═══════════════════════════════════════════════════════════
def cmd_info(args):
    """Show the quantization types with bits, compression and a quality grade.

    The TurboQuant bindings must import and ``TurboQuant("cpu")`` must
    construct. Otherwise a JSON error object is printed to stdout (regardless
    of --json) and EXIT_LIB_MISSING is returned. The table itself is static
    data defined here. Returns EXIT_OK after printing it.
    """
    try:
        from turboquant import TurboQuant
        _engine = TurboQuant("cpu")  # availability probe only; not used below
    except Exception as err:
        print(json.dumps({"error": "TurboQuant library not found", "detail": str(err)}))
        return EXIT_LIB_MISSING

    columns = ("name", "id", "bits", "compression", "grade", "recommended")
    table = (
        ("uniform_4b", 5, 4.2, 7.5, "A+", True),
        ("mixed_4b8", 7, 5.0, 6.4, "A+", True),
        ("uniform_2b", 6, 2.2, 14.2, "A", False),
        ("turbo_3b", 3, 5.8, 4.6, "B+", False),
        ("polar_4b", 1, 4.5, 7.1, "B", False),
        ("qjl_1b", 2, 1.2, 25.6, "C", False),
    )
    types = [dict(zip(columns, row)) for row in table]

    if args.json_output:
        print(json.dumps({"types": types}, indent=2))
        return EXIT_OK

    print(f"\n {C.BOLD}TurboQuant Quantization Types{C.NC}")
    print(" Ranked by real Qwen3.5-0.8B A/B test results\n")
    print(f" {C.BOLD}{'Type':<14} {'Bits':>5} {'Compress':>9} {'Grade':>6} {'Note':<20}{C.NC}")
    print(f" {'─'*14} {'─'*5} {'─'*9} {'─'*6} {'─'*20}")
    for entry in types:
        recommended = entry["recommended"]
        star = f"{C.GREEN}★{C.NC}" if recommended else " "
        note = "← recommended" if recommended else ""
        print(f" {star} {entry['name']:<12} {entry['bits']:>5.1f} {entry['compression']:>8.1f}x {entry['grade']:>5} {note}")
    print()
    return EXIT_OK
def cmd_bench(args):
    """Benchmark key quantization for uniform_4b, mixed_4b8 and uniform_2b.

    For each type (ids 5, 7, 6) this measures the following on random keys
    of shape ``(seq_len, head_dim)``:
      - mean ``quantize_keys`` time over ``reps`` runs
      - dequantization MSE
      - cosine similarity of ``tq.attention`` scores vs. FP32 ``keys @ query``
      - the ratio ``keys.nbytes / len(quantized)``

    Returns:
        EXIT_LIB_MISSING if numpy or the TurboQuant bindings are unavailable
        (a JSON error object is printed); EXIT_USAGE for negative sizes;
        otherwise EXIT_OK.
    """
    # numpy is optional at module level (np is None when it is missing).
    if np is None:
        print(json.dumps({"error": "numpy is required for `tq bench` (pip install numpy)"}))
        return EXIT_LIB_MISSING
    try:
        from turboquant import TurboQuant
        tq = TurboQuant("cpu")
    except Exception as e:
        print(json.dumps({"error": str(e)}))
        return EXIT_LIB_MISSING

    # `or` keeps the historical behaviour: omitted (or 0) means "use default".
    seq_len = args.seq_len or 512
    head_dim = args.head_dim or 128
    if seq_len < 0 or head_dim < 0:
        print(json.dumps({"error": "--seq-len and --head-dim must be positive"}))
        return EXIT_USAGE
    reps = 500

    # Private RandomState(42): same values as the former np.random.seed(42)
    # stream, without clobbering numpy's global RNG state.
    rng = np.random.RandomState(42)
    keys = rng.randn(seq_len, head_dim).astype(np.float32) * 0.15
    query = rng.randn(head_dim).astype(np.float32) * 0.15
    fp32_scores = keys @ query  # reference scores; independent of the quant type

    results = []
    for qtype, name in ((5, "uniform_4b"), (7, "mixed_4b8"), (6, "uniform_2b")):
        t0 = time.perf_counter()  # monotonic, high-resolution clock for timing
        for _ in range(reps):
            q = tq.quantize_keys(keys, qtype)
        quant_time = (time.perf_counter() - t0) / reps

        deq = tq.dequantize_keys(q, seq_len, head_dim, qtype)
        mse = float(np.mean((keys - deq) ** 2))
        scores = tq.attention(query, q, seq_len, head_dim, qtype)
        denom = np.linalg.norm(scores) * np.linalg.norm(fp32_scores) + 1e-10
        cos = float(np.dot(scores, fp32_scores) / denom)
        results.append({
            "type": name, "seq_len": seq_len, "head_dim": head_dim,
            "mse": round(mse, 6), "cosine": round(cos, 4),
            "quant_ms": round(quant_time * 1000, 3),
            "compression": round(keys.nbytes / len(q), 1),
        })

    if args.json_output:
        print(json.dumps({"benchmark": results}, indent=2))
        return EXIT_OK

    print(f"\n {C.BOLD}TurboQuant Benchmark{C.NC} (seq={seq_len}, dim={head_dim})\n")
    print(f" {C.BOLD}{'Type':<14} {'MSE':>10} {'Cosine':>8} {'Time':>8} {'Compress':>9}{C.NC}")
    print(f" {'─'*14} {'─'*10} {'─'*8} {'─'*8} {'─'*9}")
    for r in results:
        g = C.GREEN if r["cosine"] > 0.99 else C.YELLOW if r["cosine"] > 0.95 else C.RED
        print(f" {r['type']:<14} {r['mse']:>10.6f} {g}{r['cosine']:>8.4f}{C.NC} {r['quant_ms']:>6.1f}ms {r['compression']:>7.1f}x")
    print()
    return EXIT_OK
def cmd_memory(args):
    """Estimate KV-cache memory for a known model at a given context length.

    Compares the FP16 baseline with TurboQuant uniform_4b (4.2 bits),
    K4V2 (keys 4.2 bits / values 2.2 bits) and uniform_2b (2.2 bits).
    Prints a JSON object when ``args.json_output`` is set, otherwise a bar chart.

    Returns:
        EXIT_USAGE for an unknown model or a non-positive context length
        (which previously crashed with ZeroDivisionError); otherwise EXIT_OK.
    """
    models = {
        "qwen3.5-0.8b": {"layers": 6, "kv_heads": 2, "head_dim": 256, "params": 0.8},
        "llama-3.2-1b": {"layers": 16, "kv_heads": 8, "head_dim": 64, "params": 1.2},
        "llama-3.2-3b": {"layers": 28, "kv_heads": 8, "head_dim": 128, "params": 3.2},
        "phi-3-mini": {"layers": 32, "kv_heads": 32, "head_dim": 96, "params": 3.8},
    }
    model_key = args.model.lower().replace(" ", "-")
    if model_key not in models:
        avail = ", ".join(models.keys())
        if args.json_output:
            print(json.dumps({"error": f"Unknown model: {args.model}", "available": list(models.keys())}))
        else:
            print(f" {C.RED}Unknown model: {args.model}{C.NC}")
            print(f" Available: {avail}")
        return EXIT_USAGE

    ctx = args.context
    if ctx <= 0:
        # fp16 would be 0 and every ratio below divides by it.
        message = f"Context length must be a positive integer, got {ctx}"
        if args.json_output:
            print(json.dumps({"error": message}))
        else:
            print(f" {C.RED}{message}{C.NC}")
        return EXIT_USAGE

    m = models[model_key]
    # x2 for keys and values, x2 bytes per FP16 element.
    fp16 = m["layers"] * m["kv_heads"] * m["head_dim"] * ctx * 2 * 2
    tq4b = fp16 * 4.2 / 16
    # Keys and values each account for half of fp16: 4.2 bits for K, 2.2 for V.
    k4v2 = fp16 * (4.2 + 2.2) / 2 / 16
    tq2b = fp16 * 2.2 / 16
    saved = fp16 - k4v2
    saved_pct = (1 - k4v2 / fp16) * 100

    result = {
        "model": args.model, "context": ctx,
        "fp16_bytes": int(fp16),
        "uniform_4b_bytes": int(tq4b),
        "k4v2_bytes": int(k4v2),
        "uniform_2b_bytes": int(tq2b),
        "saved_k4v2_bytes": int(saved),
        "saved_pct": round(saved_pct, 1),
    }
    if args.json_output:
        print(json.dumps(result, indent=2))
        return EXIT_OK

    # Use the "K" suffix only when exact (1500 used to be shown as "1K").
    ctx_str = f"{ctx // 1024}K" if ctx >= 1024 and ctx % 1024 == 0 else str(ctx)
    print(f"\n {C.BOLD}Memory: {args.model} @ {ctx_str} context{C.NC}\n")
    configs = [
        ("FP16 (baseline)", fp16, C.RED),
        ("TQ uniform_4b", tq4b, C.GREEN),
        ("TQ K4V2", k4v2, C.GREEN),
        ("TQ uniform_2b", tq2b, C.YELLOW),
    ]
    for name, size, _color in configs:
        comp = fp16 / size if size > 0 else 1
        print(f" {name:<20} {sz(size):>10} {comp:>5.1f}x {bar(size, fp16)}")
    print(f"\n {C.GREEN}{C.BOLD}Best balance (K4V2): saves {sz(saved)} ({saved_pct:.0f}%){C.NC}\n")
    return EXIT_OK
def cmd_compare(args):
    """Run the native A/B comparison binary ``build/ab_test``.

    The binary is searched for in two places:
      1. ``./build/ab_test`` relative to the current directory (the
         historical location)
      2. ``<project>/build/ab_test``, where ``<project>`` is the parent of
         this script's directory
    The process is replaced by the binary, so its exit status becomes tq's.
    Previously it ran in a throwaway Python subprocess whose status was
    discarded.

    Returns:
        EXIT_LIB_MISSING if the binary is missing or cannot be executed.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    project = os.path.dirname(here)
    candidates = (
        os.path.join("build", "ab_test"),
        os.path.join(project, "build", "ab_test"),
    )
    binary = next(
        (os.path.abspath(c) for c in candidates if os.path.isfile(c) and os.access(c, os.X_OK)),
        None,
    )
    if binary is None:
        print(f"{C.RED}ab_test binary not found:{C.NC} build the project first "
              f"(e.g. `cmake --build build`)", file=sys.stderr)
        return EXIT_LIB_MISSING

    # exec* does not flush Python's stdio buffers.
    sys.stdout.flush()
    sys.stderr.flush()
    try:
        os.execv(binary, [binary])
    except OSError as e:
        print(f"{C.RED}failed to start{C.NC} {binary}: {e}", file=sys.stderr)
        return EXIT_LIB_MISSING
# ═══════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════
def main():
    """Parse ``sys.argv``, dispatch to the matching command, return an exit code.

    Prints help and returns EXIT_USAGE when no subcommand is given. ``run``,
    ``serve``, ``+compare`` and ``demo`` replace the process via ``os.exec*``
    on success.
    """
    parser = argparse.ArgumentParser(
        prog="tq",
        description="TurboQuant CLI — KV cache compression for LLM inference",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
commands:
  pull MODEL           Download a model (e.g., llama3.2:1b)
  list                 List cached and available models
  run MODEL [PROMPT]   Chat with a model (auto-pulls if needed)
  serve MODEL          Start OpenAI-compatible HTTP server
  info                 Show quantization types and recommendations
  bench                Run performance benchmark
  +memory MODEL CTX    Calculate memory savings
  +compare             Run A/B comparison (requires build)
  demo                 Chat with Qwen3.5-0.8B (legacy, use `run` instead)

examples:
  tq pull llama3.2:1b
  tq list
  tq run llama3.2:1b
  tq run llama3.2:1b "What is gravity?"
  tq serve llama3.2:1b --port 8080
  tq info --json
  tq bench --seq-len 2048 --head-dim 256
""")
    json_help = "JSON output (for AI agents)"
    parser.add_argument("--json", dest="json_output", action="store_true", help=json_help)

    def add_json_flag(subparser):
        # default=SUPPRESS: argparse copies every subparser default onto the
        # shared namespace, so a plain False default would silently undo the
        # global form (`tq --json list` printed the human table).
        subparser.add_argument("--json", dest="json_output", action="store_true",
                               default=argparse.SUPPRESS, help=json_help)

    sub = parser.add_subparsers(dest="command")

    p_pull = sub.add_parser("pull", help="Download a model from HuggingFace")
    p_pull.add_argument("model", help="Model name or alias (e.g., llama3.2:1b)")

    p_list = sub.add_parser("list", help="List cached and available models")
    add_json_flag(p_list)

    p_run = sub.add_parser("run", help="Chat with a model (auto-pulls if needed)")
    p_run.add_argument("model", help="Model name or alias")
    p_run.add_argument("prompt", nargs="?", default=None, help="Optional prompt (interactive if omitted)")
    p_run.add_argument("-j", "--threads", type=int, default=4)
    p_run.add_argument("-n", "--max-tokens", type=int, default=256)

    p_serve = sub.add_parser("serve", help="Start OpenAI-compatible HTTP server")
    p_serve.add_argument("model", help="Model name or alias")
    p_serve.add_argument("-p", "--port", type=int, default=8080)
    p_serve.add_argument("-j", "--threads", type=int, default=4)

    p_info = sub.add_parser("info", help="Quantization type information")
    add_json_flag(p_info)

    p_bench = sub.add_parser("bench", help="Performance benchmark")
    p_bench.add_argument("--seq-len", type=int)
    p_bench.add_argument("--head-dim", type=int)
    add_json_flag(p_bench)

    p_mem = sub.add_parser("+memory", help="Memory savings calculator")
    p_mem.add_argument("model", help="Model name (e.g., llama-3.2-3b)")
    p_mem.add_argument("context", type=int, help="Context length in tokens")
    add_json_flag(p_mem)

    sub.add_parser("+compare", help="Run A/B comparison")

    p_demo = sub.add_parser("demo", help="Chat with Qwen3.5-0.8B (native C engine)")
    p_demo.add_argument("question", nargs="?", help="Question (interactive if omitted)")
    p_demo.add_argument("--engine", choices=["native", "pytorch"], default="native",
                        help="Inference engine: native (quant, default) or pytorch")

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        return EXIT_USAGE

    handlers = {
        "pull": cmd_pull,
        "list": cmd_list,
        "run": cmd_run,
        "serve": cmd_serve,
        "info": cmd_info,
        "bench": cmd_bench,
        "+memory": cmd_memory,
        "+compare": cmd_compare,
    }
    handler = handlers.get(args.command)
    if handler is not None:
        return handler(args)

    if args.command == "demo":
        demo_args = ["--engine", args.engine]
        if args.question:
            demo_args.append(args.question)
        script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tq_chat.py")
        argv = [sys.executable, script, *demo_args]
        # exec* does not flush Python's stdio buffers.
        sys.stdout.flush()
        sys.stderr.flush()
        try:
            os.execvp(argv[0], argv)
        except OSError as e:
            print(f"failed to start demo ({script}): {e}", file=sys.stderr)
            return EXIT_LIB_MISSING
    return EXIT_OK
if __name__ == "__main__":
sys.exit(main())