feat(workers): free Ollama VRAM before loading Whisper on the GPU

Whisper (CT 311) and Ollama (CT 102) share one A2000. Before loading
Whisper on CUDA, ask Ollama to unload its models (GET /api/ps then POST
/api/generate keep_alive:0) and wait for the card to clear, so the GPU
load has headroom. Best-effort and stdlib-only; Ollama reloads
cooperatively, and the existing CUDA->CPU fallback covers any failure.
Toggle via OLLAMA_FREE_BEFORE_STT; endpoint via OLLAMA_URL.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
root
2026-06-05 21:12:05 +10:00
parent c2569cad76
commit a9191cee00
5 changed files with 145 additions and 6 deletions

47
workers/tests/test_gpu.py Normal file
View File

@@ -0,0 +1,47 @@
from unittest.mock import patch, call
from void_workers import gpu, config
def test_free_unloads_each_loaded_model(monkeypatch):
monkeypatch.setattr(config, "OLLAMA_FREE_BEFORE_STT", True)
calls = []
def fake_http(method, url, body=None, timeout=5):
calls.append((method, url, body))
if url.endswith("/api/ps"):
# loaded first, then empty after the unloads (confirm-poll)
return {"models": [{"name": "llama3.1:8b"}]} if len([c for c in calls if c[1].endswith("/api/ps")]) == 1 else {"models": []}
return {}
with patch("void_workers.gpu._http", side_effect=fake_http):
freed = gpu.free_ollama_vram(base="http://x:11434")
assert freed == ["llama3.1:8b"]
# an unload POST with keep_alive:0 was issued for the loaded model
assert (
"POST",
"http://x:11434/api/generate",
{"model": "llama3.1:8b", "keep_alive": 0},
) in calls
def test_free_is_noop_when_disabled(monkeypatch):
monkeypatch.setattr(config, "OLLAMA_FREE_BEFORE_STT", False)
with patch("void_workers.gpu._http") as h:
assert gpu.free_ollama_vram(base="http://x:11434") == []
h.assert_not_called()
def test_free_is_noop_when_nothing_loaded(monkeypatch):
monkeypatch.setattr(config, "OLLAMA_FREE_BEFORE_STT", True)
with patch("void_workers.gpu._http", return_value={"models": []}) as h:
assert gpu.free_ollama_vram(base="http://x:11434") == []
# only the /api/ps probe, no unload POST
assert all(c.args[0] == "GET" for c in h.call_args_list)
def test_free_never_raises_when_ollama_unreachable(monkeypatch):
monkeypatch.setattr(config, "OLLAMA_FREE_BEFORE_STT", True)
with patch("void_workers.gpu._http", side_effect=OSError("connection refused")):
# ps fails -> [] -> no unload -> returns [] without propagating
assert gpu.free_ollama_vram(base="http://x:11434") == []

View File

@@ -15,10 +15,24 @@ def test_model_returns_singleton(monkeypatch):
def test_uses_gpu_when_available(monkeypatch):
monkeypatch.setattr(model, "_whisper_model", None)
with patch("void_workers.model.cuda_available", return_value=True):
with patch("faster_whisper.WhisperModel", return_value=MagicMock()) as WM:
model.whisper_model()
assert WM.call_args.kwargs["device"] == "cuda"
assert WM.call_args.kwargs["compute_type"] == "float16"
with patch("void_workers.gpu.free_ollama_vram", return_value=[]):
with patch("faster_whisper.WhisperModel", return_value=MagicMock()) as WM:
model.whisper_model()
assert WM.call_args.kwargs["device"] == "cuda"
assert WM.call_args.kwargs["compute_type"] == "float16"
def test_frees_ollama_before_gpu_load(monkeypatch):
# Ollama VRAM must be freed BEFORE the cuda model is constructed.
monkeypatch.setattr(model, "_whisper_model", None)
order = []
with patch("void_workers.model.cuda_available", return_value=True):
with patch("void_workers.gpu.free_ollama_vram",
side_effect=lambda *a, **k: order.append("free")):
with patch("faster_whisper.WhisperModel",
side_effect=lambda *a, **k: order.append("load") or MagicMock()):
model.whisper_model()
assert order == ["free", "load"]
def test_falls_back_to_cpu_when_cuda_load_fails(monkeypatch):
@@ -33,8 +47,9 @@ def test_falls_back_to_cpu_when_cuda_load_fails(monkeypatch):
return cpu_model
with patch("void_workers.model.cuda_available", return_value=True):
with patch("faster_whisper.WhisperModel", side_effect=fake_ctor):
got = model.whisper_model()
with patch("void_workers.gpu.free_ollama_vram", return_value=[]):
with patch("faster_whisper.WhisperModel", side_effect=fake_ctor):
got = model.whisper_model()
assert got is cpu_model