diff --git a/workers/tests/test_gpu.py b/workers/tests/test_gpu.py new file mode 100644 index 0000000..42952d0 --- /dev/null +++ b/workers/tests/test_gpu.py @@ -0,0 +1,47 @@ +from unittest.mock import patch, call +from void_workers import gpu, config + + +def test_free_unloads_each_loaded_model(monkeypatch): + monkeypatch.setattr(config, "OLLAMA_FREE_BEFORE_STT", True) + calls = [] + + def fake_http(method, url, body=None, timeout=5): + calls.append((method, url, body)) + if url.endswith("/api/ps"): + # loaded first, then empty after the unloads (confirm-poll) + return {"models": [{"name": "llama3.1:8b"}]} if len([c for c in calls if c[1].endswith("/api/ps")]) == 1 else {"models": []} + return {} + + with patch("void_workers.gpu._http", side_effect=fake_http): + freed = gpu.free_ollama_vram(base="http://x:11434") + + assert freed == ["llama3.1:8b"] + # an unload POST with keep_alive:0 was issued for the loaded model + assert ( + "POST", + "http://x:11434/api/generate", + {"model": "llama3.1:8b", "keep_alive": 0}, + ) in calls + + +def test_free_is_noop_when_disabled(monkeypatch): + monkeypatch.setattr(config, "OLLAMA_FREE_BEFORE_STT", False) + with patch("void_workers.gpu._http") as h: + assert gpu.free_ollama_vram(base="http://x:11434") == [] + h.assert_not_called() + + +def test_free_is_noop_when_nothing_loaded(monkeypatch): + monkeypatch.setattr(config, "OLLAMA_FREE_BEFORE_STT", True) + with patch("void_workers.gpu._http", return_value={"models": []}) as h: + assert gpu.free_ollama_vram(base="http://x:11434") == [] + # only the /api/ps probe, no unload POST + assert all(c.args[0] == "GET" for c in h.call_args_list) + + +def test_free_never_raises_when_ollama_unreachable(monkeypatch): + monkeypatch.setattr(config, "OLLAMA_FREE_BEFORE_STT", True) + with patch("void_workers.gpu._http", side_effect=OSError("connection refused")): + # ps fails -> [] -> no unload -> returns [] without propagating + assert gpu.free_ollama_vram(base="http://x:11434") == [] diff --git a/workers/tests/test_model.py b/workers/tests/test_model.py index a683258..18b16a5 100644 --- a/workers/tests/test_model.py +++ b/workers/tests/test_model.py @@ -15,10 +15,24 @@ def test_model_returns_singleton(monkeypatch): def test_uses_gpu_when_available(monkeypatch): monkeypatch.setattr(model, "_whisper_model", None) with patch("void_workers.model.cuda_available", return_value=True): - with patch("faster_whisper.WhisperModel", return_value=MagicMock()) as WM: - model.whisper_model() - assert WM.call_args.kwargs["device"] == "cuda" - assert WM.call_args.kwargs["compute_type"] == "float16" + with patch("void_workers.gpu.free_ollama_vram", return_value=[]): + with patch("faster_whisper.WhisperModel", return_value=MagicMock()) as WM: + model.whisper_model() + assert WM.call_args.kwargs["device"] == "cuda" + assert WM.call_args.kwargs["compute_type"] == "float16" + + +def test_frees_ollama_before_gpu_load(monkeypatch): + # Ollama VRAM must be freed BEFORE the cuda model is constructed. + monkeypatch.setattr(model, "_whisper_model", None) + order = [] + with patch("void_workers.model.cuda_available", return_value=True): + with patch("void_workers.gpu.free_ollama_vram", + side_effect=lambda *a, **k: order.append("free")): + with patch("faster_whisper.WhisperModel", + side_effect=lambda *a, **k: order.append("load") or MagicMock()): + model.whisper_model() + assert order == ["free", "load"] def test_falls_back_to_cpu_when_cuda_load_fails(monkeypatch): @@ -33,8 +47,9 @@ def test_falls_back_to_cpu_when_cuda_load_fails(monkeypatch): return cpu_model with patch("void_workers.model.cuda_available", return_value=True): - with patch("faster_whisper.WhisperModel", side_effect=fake_ctor): - got = model.whisper_model() + with patch("void_workers.gpu.free_ollama_vram", return_value=[]): + with patch("faster_whisper.WhisperModel", side_effect=fake_ctor): + got = model.whisper_model() assert got is cpu_model diff --git a/workers/void_workers/config.py b/workers/void_workers/config.py index 945d958..4bd6802 100644 --- a/workers/void_workers/config.py +++ b/workers/void_workers/config.py @@ -15,6 +15,12 @@ WHISPER_MODEL = env("WHISPER_MODEL", "small.en") WHISPER_CACHE = env("WHISPER_CACHE", "/var/lib/void/whisper-models") ALLOW_PRIVATE = env("VOID_INGEST_ALLOW_PRIVATE", "false") == "true" +# GPU sharing: Whisper and Ollama (CT 102) share one A2000. Before loading +# Whisper on the GPU, ask Ollama to unload its models to make room (it reloads +# cooperatively on its next request). Best-effort; CPU fallback covers failure. +OLLAMA_URL = env("OLLAMA_URL", "http://192.168.1.185:11434") +OLLAMA_FREE_BEFORE_STT = env("OLLAMA_FREE_BEFORE_STT", "true") == "true" + CONCURRENCY = { "extract.pdf": env_int("VOID_CONCURRENCY_EXTRACT_PDF", 2), "extract.image": env_int("VOID_CONCURRENCY_EXTRACT_IMAGE", 2), diff --git a/workers/void_workers/gpu.py b/workers/void_workers/gpu.py new file mode 100644 index 0000000..1d62f11 --- /dev/null +++ b/workers/void_workers/gpu.py @@ -0,0 +1,65 @@ +"""Cooperative GPU sharing with Ollama. + +Whisper (this worker, CT 311) and Ollama (CT 102) both pass through Z's single +RTX A2000. Before Whisper loads on the GPU we ask Ollama to unload its models so +there's room; Ollama transparently reloads on its next request. Everything here +is best-effort and never raises — if Ollama is unreachable or slow, Whisper +still tries the GPU and falls back to CPU (see model.py). + +Stdlib urllib only (the workers carry no `requests`/`httpx` dependency). +""" +import json +import time +import urllib.request + +from .log import log +from . import config + + +def _http(method, url, body=None, timeout=5): + data = json.dumps(body).encode() if body is not None else None + req = urllib.request.Request( + url, data=data, method=method, + headers={"Content-Type": "application/json"}, + ) + with urllib.request.urlopen(req, timeout=timeout) as r: + raw = r.read().decode() + return json.loads(raw) if raw else {} + + +def loaded_ollama_models(base=None, timeout=3): + """Names of models Ollama currently holds in memory (GET /api/ps).""" + base = base or config.OLLAMA_URL + try: + data = _http("GET", f"{base}/api/ps", timeout=timeout) + return [m["name"] for m in data.get("models", []) if m.get("name")] + except Exception as e: + log.info("ollama_ps_failed", err=str(e)) + return [] + + +def free_ollama_vram(base=None, wait_s=6.0): + """Ask Ollama to unload its loaded models, then wait (briefly) for the VRAM + to actually free. Returns the list of models it tried to unload. No-op when + OLLAMA_FREE_BEFORE_STT is disabled or nothing is loaded. Never raises.""" + if not config.OLLAMA_FREE_BEFORE_STT: + return [] + base = base or config.OLLAMA_URL + models = loaded_ollama_models(base) + if not models: + return [] + # keep_alive:0 tells Ollama to drop the model from memory immediately. + for name in models: + try: + _http("POST", f"{base}/api/generate", + {"model": name, "keep_alive": 0}, timeout=8) + except Exception as e: + log.info("ollama_unload_failed", model=name, err=str(e)) + # Confirm the card is actually clear before we hand it to Whisper. + deadline = time.monotonic() + wait_s + while time.monotonic() < deadline: + if not loaded_ollama_models(base): + break + time.sleep(0.3) + log.info("ollama_vram_freed", models=models) + return models diff --git a/workers/void_workers/model.py b/workers/void_workers/model.py index 727bf60..758f0fd 100644 --- a/workers/void_workers/model.py +++ b/workers/void_workers/model.py @@ -32,6 +32,12 @@ def whisper_model(): # another process sharing the card). HA portability + a shared GPU # mean this must degrade gracefully, never hard-fail a transcription. if cuda_available(): + # Make room on the shared GPU first (best-effort; never raises). + try: + from . import gpu + gpu.free_ollama_vram() + except Exception as e: + log.info("ollama_free_skipped", err=str(e)) try: _whisper_model = _load_whisper("cuda", "float16") except Exception as e: