feat(workers): whisper loader with CUDA detect + CPU fallback
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
23
workers/tests/test_model.py
Normal file
23
workers/tests/test_model.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from void_workers import model
|
||||
|
||||
|
||||
def test_model_returns_singleton(monkeypatch):
|
||||
m = MagicMock()
|
||||
monkeypatch.setattr(model, "_whisper_model", None)
|
||||
with patch("void_workers.model.cuda_available", return_value=False):
|
||||
with patch("faster_whisper.WhisperModel", return_value=m):
|
||||
a = model.whisper_model()
|
||||
b = model.whisper_model()
|
||||
assert a is b
|
||||
|
||||
|
||||
def test_transcribe_returns_joined_segments(monkeypatch):
|
||||
seg1 = MagicMock(text=" Hello world ")
|
||||
seg2 = MagicMock(text=" second line")
|
||||
fake_model = MagicMock()
|
||||
fake_model.transcribe.return_value = ([seg1, seg2], MagicMock())
|
||||
monkeypatch.setattr(model, "_whisper_model", fake_model)
|
||||
out = model.whisper_transcribe("/tmp/whatever.opus")
|
||||
assert "Hello world" in out
|
||||
assert "second line" in out
|
||||
34
workers/void_workers/model.py
Normal file
34
workers/void_workers/model.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import os
|
||||
from .log import log
|
||||
|
||||
_whisper_model = None
|
||||
|
||||
|
||||
def cuda_available():
|
||||
try:
|
||||
import ctranslate2
|
||||
return ctranslate2.get_cuda_device_count() > 0
|
||||
except Exception as e:
|
||||
log.info("ctranslate2_cuda_probe_failed", err=str(e))
|
||||
return False
|
||||
|
||||
|
||||
def whisper_model():
|
||||
global _whisper_model
|
||||
if _whisper_model is None:
|
||||
from faster_whisper import WhisperModel
|
||||
name = os.environ.get("WHISPER_MODEL", "small.en")
|
||||
cache = os.environ.get("WHISPER_CACHE", "/var/lib/void/whisper-models")
|
||||
device = "cuda" if cuda_available() else "cpu"
|
||||
compute_type = "float16" if device == "cuda" else "int8"
|
||||
log.info("whisper_loading", model=name, device=device,
|
||||
compute_type=compute_type, cache=cache)
|
||||
_whisper_model = WhisperModel(
|
||||
name, device=device, compute_type=compute_type, download_root=cache
|
||||
)
|
||||
return _whisper_model
|
||||
|
||||
|
||||
def whisper_transcribe(audio_path):
|
||||
segments, _info = whisper_model().transcribe(audio_path, vad_filter=True)
|
||||
return "\n".join(s.text.strip() for s in segments).strip()
|
||||
Reference in New Issue
Block a user