diff --git a/workers/tests/test_model.py b/workers/tests/test_model.py new file mode 100644 index 0000000..250a4ae --- /dev/null +++ b/workers/tests/test_model.py @@ -0,0 +1,23 @@ +from unittest.mock import patch, MagicMock +from void_workers import model + + +def test_model_returns_singleton(monkeypatch): + m = MagicMock() + monkeypatch.setattr(model, "_whisper_model", None) + with patch("void_workers.model.cuda_available", return_value=False): + with patch("faster_whisper.WhisperModel", return_value=m): + a = model.whisper_model() + b = model.whisper_model() + assert a is b + + +def test_transcribe_returns_joined_segments(monkeypatch): + seg1 = MagicMock(text=" Hello world ") + seg2 = MagicMock(text=" second line") + fake_model = MagicMock() + fake_model.transcribe.return_value = ([seg1, seg2], MagicMock()) + monkeypatch.setattr(model, "_whisper_model", fake_model) + out = model.whisper_transcribe("/tmp/whatever.opus") + assert "Hello world" in out + assert "second line" in out diff --git a/workers/void_workers/model.py b/workers/void_workers/model.py new file mode 100644 index 0000000..0212671 --- /dev/null +++ b/workers/void_workers/model.py @@ -0,0 +1,34 @@ +import os +from .log import log + +_whisper_model = None + + +def cuda_available(): + try: + import ctranslate2 + return ctranslate2.get_cuda_device_count() > 0 + except Exception as e: + log.info("ctranslate2_cuda_probe_failed", err=str(e)) + return False + + +def whisper_model(): + global _whisper_model + if _whisper_model is None: + from faster_whisper import WhisperModel + name = os.environ.get("WHISPER_MODEL", "small.en") + cache = os.environ.get("WHISPER_CACHE", "/var/lib/void/whisper-models") + device = "cuda" if cuda_available() else "cpu" + compute_type = "float16" if device == "cuda" else "int8" + log.info("whisper_loading", model=name, device=device, + compute_type=compute_type, cache=cache) + _whisper_model = WhisperModel( + name, device=device, compute_type=compute_type, download_root=cache + ) + return _whisper_model + + +def whisper_transcribe(audio_path): + segments, _info = whisper_model().transcribe(audio_path, vad_filter=True) + return "\n".join(s.text.strip() for s in segments).strip()