mirror of
https://github.com/home-assistant/core.git
synced 2025-11-08 18:39:30 +00:00
208 lines
6.5 KiB
Python
208 lines
6.5 KiB
Python
"""Support for the ElevenLabs speech-to-text service."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterable
|
|
from io import BytesIO
|
|
import logging
|
|
|
|
from elevenlabs import AsyncElevenLabs
|
|
from elevenlabs.core import ApiError
|
|
from elevenlabs.types import Model
|
|
|
|
from homeassistant.components import stt
|
|
from homeassistant.components.stt import (
|
|
AudioBitRates,
|
|
AudioChannels,
|
|
AudioCodecs,
|
|
AudioFormats,
|
|
AudioSampleRates,
|
|
SpeechMetadata,
|
|
SpeechResultState,
|
|
SpeechToTextEntity,
|
|
)
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
|
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
|
|
|
from . import ElevenLabsConfigEntry
|
|
from .const import (
|
|
CONF_STT_AUTO_LANGUAGE,
|
|
DEFAULT_STT_AUTO_LANGUAGE,
|
|
DOMAIN,
|
|
STT_LANGUAGES,
|
|
)
|
|
|
|
# Module-level logger for this platform.
_LOGGER = logging.getLogger(__name__)

# Cap on concurrent actions Home Assistant dispatches to this platform's
# entities (here: simultaneous STT requests to the ElevenLabs API).
PARALLEL_UPDATES = 10
|
|
|
|
|
|
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ElevenLabsConfigEntry,
    async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
    """Set up the ElevenLabs speech-to-text platform from a config entry."""
    runtime = config_entry.runtime_data
    # User-selected automatic language detection; falls back to the
    # integration default when the option has never been set.
    detect_automatically = config_entry.options.get(
        CONF_STT_AUTO_LANGUAGE, DEFAULT_STT_AUTO_LANGUAGE
    )

    entity = ElevenLabsSTTEntity(
        runtime.client,
        runtime.model,
        runtime.stt_model,
        config_entry.entry_id,
        auto_detect_language=detect_automatically,
    )
    async_add_entities([entity])
|
|
|
|
|
|
class ElevenLabsSTTEntity(SpeechToTextEntity):
    """The ElevenLabs STT API entity."""

    _attr_has_entity_name = True
    _attr_translation_key = "elevenlabs_stt"

    def __init__(
        self,
        client: AsyncElevenLabs,
        model: Model,
        stt_model: str,
        entry_id: str,
        auto_detect_language: bool = False,
    ) -> None:
        """Init ElevenLabs STT service.

        client: authenticated ElevenLabs API client.
        model: ElevenLabs model whose name is used to label the device entry.
        stt_model: id of the speech-to-text model to transcribe with.
        entry_id: config entry id, reused as unique id and device identifier.
        auto_detect_language: when True, let the API detect the spoken
            language instead of sending the metadata language.
        """
        self._client = client
        self._auto_detect_language = auto_detect_language
        self._stt_model = stt_model

        # Entity attributes
        self._attr_unique_id = entry_id
        self._attr_device_info = DeviceInfo(
            identifiers={(DOMAIN, entry_id)},
            manufacturer="ElevenLabs",
            model=model.name,
            name="ElevenLabs",
            entry_type=DeviceEntryType.SERVICE,
        )

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages."""
        return STT_LANGUAGES

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Return a list of supported formats."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return a list of supported codecs."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return a list of supported bit rates."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return a list of supported sample rates."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Return a list of supported channels."""
        return [
            AudioChannels.CHANNEL_MONO,
            AudioChannels.CHANNEL_STEREO,
        ]

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Process an audio stream to STT service.

        Buffers the whole stream in memory, forwards it to the ElevenLabs
        speech-to-text endpoint, and returns the transcription (or an ERROR
        result on unsupported input, empty audio, or an API failure).
        """
        _LOGGER.debug(
            "Processing audio stream for STT: model=%s, language=%s, format=%s, codec=%s, sample_rate=%s, channels=%s, bit_rate=%s",
            self._stt_model,
            metadata.language,
            metadata.format,
            metadata.codec,
            metadata.sample_rate,
            metadata.channel,
            metadata.bit_rate,
        )

        if self._auto_detect_language:
            # None tells the API to detect the language itself.
            lang_code = None
        else:
            language = metadata.language
            # Case-insensitive membership check against the supported list;
            # anything else is rejected up front.
            if language.lower() not in [lang.lower() for lang in STT_LANGUAGES]:
                _LOGGER.warning("Unsupported language: %s", language)
                return stt.SpeechResult(None, SpeechResultState.ERROR)
            # Send only the primary subtag (e.g. "en" from "en-US").
            lang_code = language.split("-")[0]

        # Raw PCM can only be sent with the "pcm_s16le_16" file format when it
        # matches exactly: 16-bit, 16 kHz, mono (see format string below).
        raw_pcm_compatible = (
            metadata.codec == AudioCodecs.PCM
            and metadata.sample_rate == AudioSampleRates.SAMPLERATE_16000
            and metadata.channel == AudioChannels.CHANNEL_MONO
            and metadata.bit_rate == AudioBitRates.BITRATE_16
        )
        if raw_pcm_compatible:
            file_format = "pcm_s16le_16"
        elif metadata.codec == AudioCodecs.PCM:
            # PCM in any other layout has no matching API file format.
            _LOGGER.warning("PCM input does not meet expected raw format requirements")
            return stt.SpeechResult(None, SpeechResultState.ERROR)
        else:
            # Non-PCM input is handed over as-is under the generic format.
            file_format = "other"

        # Drain the stream into one in-memory buffer before uploading.
        audio = b""
        async for chunk in stream:
            audio += chunk

        _LOGGER.debug("Finished reading audio stream, total size: %d bytes", len(audio))
        if not audio:
            _LOGGER.warning("No audio received in stream")
            return stt.SpeechResult(None, SpeechResultState.ERROR)

        lang_display = lang_code if lang_code else "auto-detected"

        _LOGGER.debug(
            "Transcribing audio (%s), format: %s, size: %d bytes",
            lang_display,
            file_format,
            len(audio),
        )

        try:
            response = await self._client.speech_to_text.convert(
                file=BytesIO(audio),
                file_format=file_format,
                model_id=self._stt_model,
                language_code=lang_code,
                tag_audio_events=False,
                num_speakers=1,
                diarize=False,
            )
        except ApiError as exc:
            _LOGGER.error("Error during processing of STT request: %s", exc)
            return stt.SpeechResult(None, SpeechResultState.ERROR)

        # Missing fields in the response are logged as "?" / empty text.
        text = response.text or ""
        detected_lang_code = response.language_code or "?"
        detected_lang_prob = response.language_probability or "?"

        _LOGGER.debug(
            "Transcribed text is in language %s (probability %s): %s",
            detected_lang_code,
            detected_lang_prob,
            text,
        )

        return stt.SpeechResult(text, SpeechResultState.SUCCESS)
|