mirror of
https://github.com/home-assistant/core.git
synced 2025-07-10 06:47:09 +00:00
Wyoming tts (#91712)
* Add tts entity * Add tts entity and tests * Re-add name to TextToSpeechEntity * Fix linting * Fix ruff linting * Support voice attr (unused) * Remove async_get_text_to_speech_entity * Move name property to Wyoming TTS entity * Fix id --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com> Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
f4df0ca50a
commit
b6f2b29a99
@ -52,10 +52,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||||||
|
|
||||||
# ASR = automated speech recognition (STT)
|
# ASR = automated speech recognition (STT)
|
||||||
asr_installed = [asr for asr in service.info.asr if asr.installed]
|
asr_installed = [asr for asr in service.info.asr if asr.installed]
|
||||||
if not asr_installed:
|
tts_installed = [tts for tts in service.info.tts if tts.installed]
|
||||||
return self.async_abort(reason="no_services")
|
|
||||||
|
|
||||||
name = asr_installed[0].name
|
if asr_installed:
|
||||||
|
name = asr_installed[0].name
|
||||||
|
elif tts_installed:
|
||||||
|
name = tts_installed[0].name
|
||||||
|
else:
|
||||||
|
return self.async_abort(reason="no_services")
|
||||||
|
|
||||||
return self.async_create_entry(title=name, data=user_input)
|
return self.async_create_entry(title=name, data=user_input)
|
||||||
|
|
||||||
|
@ -25,8 +25,10 @@ class WyomingService:
|
|||||||
self.port = port
|
self.port = port
|
||||||
self.info = info
|
self.info = info
|
||||||
platforms = []
|
platforms = []
|
||||||
if info.asr:
|
if any(asr.installed for asr in info.asr):
|
||||||
platforms.append(Platform.STT)
|
platforms.append(Platform.STT)
|
||||||
|
if any(tts.installed for tts in info.tts):
|
||||||
|
platforms.append(Platform.TTS)
|
||||||
self.platforms = platforms
|
self.platforms = platforms
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -39,14 +41,20 @@ class WyomingService:
|
|||||||
return cls(host, port, info)
|
return cls(host, port, info)
|
||||||
|
|
||||||
|
|
||||||
async def load_wyoming_info(host: str, port: int) -> Info | None:
|
async def load_wyoming_info(
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
retries: int = _INFO_RETRIES,
|
||||||
|
retry_wait: float = _INFO_RETRY_WAIT,
|
||||||
|
timeout: float = _INFO_TIMEOUT,
|
||||||
|
) -> Info | None:
|
||||||
"""Load info from Wyoming server."""
|
"""Load info from Wyoming server."""
|
||||||
wyoming_info: Info | None = None
|
wyoming_info: Info | None = None
|
||||||
|
|
||||||
for _ in range(_INFO_RETRIES):
|
for _ in range(retries + 1):
|
||||||
try:
|
try:
|
||||||
async with AsyncTcpClient(host, port) as client:
|
async with AsyncTcpClient(host, port) as client:
|
||||||
with async_timeout.timeout(_INFO_TIMEOUT):
|
with async_timeout.timeout(timeout):
|
||||||
# Describe -> Info
|
# Describe -> Info
|
||||||
await client.write_event(Describe().event())
|
await client.write_event(Describe().event())
|
||||||
while True:
|
while True:
|
||||||
@ -58,9 +66,12 @@ async def load_wyoming_info(host: str, port: int) -> Info | None:
|
|||||||
|
|
||||||
if Info.is_type(event.type):
|
if Info.is_type(event.type):
|
||||||
wyoming_info = Info.from_event(event)
|
wyoming_info = Info.from_event(event)
|
||||||
break
|
break # while
|
||||||
|
|
||||||
|
if wyoming_info is not None:
|
||||||
|
break # for
|
||||||
except (asyncio.TimeoutError, OSError, WyomingError):
|
except (asyncio.TimeoutError, OSError, WyomingError):
|
||||||
# Sleep and try again
|
# Sleep and try again
|
||||||
await asyncio.sleep(_INFO_RETRY_WAIT)
|
await asyncio.sleep(retry_wait)
|
||||||
|
|
||||||
return wyoming_info
|
return wyoming_info
|
||||||
|
161
homeassistant/components/wyoming/tts.py
Normal file
161
homeassistant/components/wyoming/tts.py
Normal file
@ -0,0 +1,161 @@
|
|||||||
|
"""Support for Wyoming text to speech services."""
|
||||||
|
from collections import defaultdict
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
|
||||||
|
from wyoming.client import AsyncTcpClient
|
||||||
|
from wyoming.tts import Synthesize
|
||||||
|
|
||||||
|
from homeassistant.components import tts
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .data import WyomingService
|
||||||
|
from .error import WyomingError
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up Wyoming speech to text."""
|
||||||
|
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
|
async_add_entities(
|
||||||
|
[
|
||||||
|
WyomingTtsProvider(config_entry, service),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WyomingTtsProvider(tts.TextToSpeechEntity):
|
||||||
|
"""Wyoming text to speech provider."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
service: WyomingService,
|
||||||
|
) -> None:
|
||||||
|
"""Set up provider."""
|
||||||
|
self.service = service
|
||||||
|
self._tts_service = next(tts for tts in service.info.tts if tts.installed)
|
||||||
|
|
||||||
|
voice_languages: set[str] = set()
|
||||||
|
self._voices: dict[str, list[tts.Voice]] = defaultdict(list)
|
||||||
|
for voice in self._tts_service.voices:
|
||||||
|
if not voice.installed:
|
||||||
|
continue
|
||||||
|
|
||||||
|
voice_languages.update(voice.languages)
|
||||||
|
for language in voice.languages:
|
||||||
|
self._voices[language].append(
|
||||||
|
tts.Voice(
|
||||||
|
voice_id=voice.name,
|
||||||
|
name=voice.name,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self._supported_languages: list[str] = list(voice_languages)
|
||||||
|
|
||||||
|
self._attr_name = self._tts_service.name
|
||||||
|
self._attr_unique_id = f"{config_entry.entry_id}-tts"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str | None:
|
||||||
|
"""Return the name of the provider entity."""
|
||||||
|
# Only one entity is allowed per platform for now.
|
||||||
|
return self._tts_service.name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_language(self):
|
||||||
|
"""Return default language."""
|
||||||
|
if not self._supported_languages:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self._supported_languages[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_languages(self):
|
||||||
|
"""Return list of supported languages."""
|
||||||
|
return self._supported_languages
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supported_options(self):
|
||||||
|
"""Return list of supported options like voice, emotion."""
|
||||||
|
return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_options(self):
|
||||||
|
"""Return a dict include default options."""
|
||||||
|
return {tts.ATTR_AUDIO_OUTPUT: "wav"}
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None:
|
||||||
|
"""Return a list of supported voices for a language."""
|
||||||
|
return self._voices.get(language)
|
||||||
|
|
||||||
|
async def async_get_tts_audio(self, message, language, options=None):
|
||||||
|
"""Load TTS from UNIX socket."""
|
||||||
|
try:
|
||||||
|
async with AsyncTcpClient(self.service.host, self.service.port) as client:
|
||||||
|
await client.write_event(Synthesize(message).event())
|
||||||
|
|
||||||
|
with io.BytesIO() as wav_io:
|
||||||
|
wav_writer: wave.Wave_write | None = None
|
||||||
|
while True:
|
||||||
|
event = await client.read_event()
|
||||||
|
if event is None:
|
||||||
|
_LOGGER.debug("Connection lost")
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
if AudioStop.is_type(event.type):
|
||||||
|
break
|
||||||
|
|
||||||
|
if AudioChunk.is_type(event.type):
|
||||||
|
chunk = AudioChunk.from_event(event)
|
||||||
|
if wav_writer is None:
|
||||||
|
wav_writer = wave.open(wav_io, "wb")
|
||||||
|
wav_writer.setframerate(chunk.rate)
|
||||||
|
wav_writer.setsampwidth(chunk.width)
|
||||||
|
wav_writer.setnchannels(chunk.channels)
|
||||||
|
|
||||||
|
wav_writer.writeframes(chunk.audio)
|
||||||
|
|
||||||
|
if wav_writer is not None:
|
||||||
|
wav_writer.close()
|
||||||
|
|
||||||
|
data = wav_io.getvalue()
|
||||||
|
|
||||||
|
except (OSError, WyomingError):
|
||||||
|
return (None, None)
|
||||||
|
|
||||||
|
if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"):
|
||||||
|
return ("wav", data)
|
||||||
|
|
||||||
|
# Raw output (convert to 16Khz, 16-bit mono)
|
||||||
|
with io.BytesIO(data) as wav_io:
|
||||||
|
wav_reader: wave.Wave_read = wave.open(wav_io, "rb")
|
||||||
|
raw_data = (
|
||||||
|
AudioChunkConverter(
|
||||||
|
rate=16000,
|
||||||
|
width=2,
|
||||||
|
channels=1,
|
||||||
|
)
|
||||||
|
.convert(
|
||||||
|
AudioChunk(
|
||||||
|
audio=wav_reader.readframes(wav_reader.getnframes()),
|
||||||
|
rate=wav_reader.getframerate(),
|
||||||
|
width=wav_reader.getsampwidth(),
|
||||||
|
channels=wav_reader.getnchannels(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.audio
|
||||||
|
)
|
||||||
|
|
||||||
|
return ("raw", raw_data)
|
@ -1,5 +1,5 @@
|
|||||||
"""Tests for the Wyoming integration."""
|
"""Tests for the Wyoming integration."""
|
||||||
from wyoming.info import AsrModel, AsrProgram, Attribution, Info
|
from wyoming.info import AsrModel, AsrProgram, Attribution, Info, TtsProgram, TtsVoice
|
||||||
|
|
||||||
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
|
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
|
||||||
STT_INFO = Info(
|
STT_INFO = Info(
|
||||||
@ -19,4 +19,53 @@ STT_INFO = Info(
|
|||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
TTS_INFO = Info(
|
||||||
|
tts=[
|
||||||
|
TtsProgram(
|
||||||
|
name="Test TTS",
|
||||||
|
installed=True,
|
||||||
|
attribution=TEST_ATTR,
|
||||||
|
voices=[
|
||||||
|
TtsVoice(
|
||||||
|
name="Test Voice",
|
||||||
|
installed=True,
|
||||||
|
attribution=TEST_ATTR,
|
||||||
|
languages=["en-US"],
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
EMPTY_INFO = Info()
|
EMPTY_INFO = Info()
|
||||||
|
|
||||||
|
|
||||||
|
class MockAsyncTcpClient:
|
||||||
|
"""Mock AsyncTcpClient."""
|
||||||
|
|
||||||
|
def __init__(self, responses) -> None:
|
||||||
|
"""Initialize."""
|
||||||
|
self.host = None
|
||||||
|
self.port = None
|
||||||
|
self.written = []
|
||||||
|
self.responses = responses
|
||||||
|
|
||||||
|
async def write_event(self, event):
|
||||||
|
"""Send."""
|
||||||
|
self.written.append(event)
|
||||||
|
|
||||||
|
async def read_event(self):
|
||||||
|
"""Receive."""
|
||||||
|
return self.responses.pop(0)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
"""Enter."""
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
"""Exit."""
|
||||||
|
|
||||||
|
def __call__(self, host, port):
|
||||||
|
"""Call."""
|
||||||
|
self.host = host
|
||||||
|
self.port = port
|
||||||
|
return self
|
||||||
|
@ -7,7 +7,7 @@ import pytest
|
|||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from . import STT_INFO
|
from . import STT_INFO, TTS_INFO
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -22,7 +22,7 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
def stt_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||||
"""Create a config entry."""
|
"""Create a config entry."""
|
||||||
entry = MockConfigEntry(
|
entry = MockConfigEntry(
|
||||||
domain="wyoming",
|
domain="wyoming",
|
||||||
@ -37,10 +37,35 @@ def config_entry(hass: HomeAssistant) -> ConfigEntry:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def init_wyoming_stt(hass: HomeAssistant, config_entry: ConfigEntry):
|
def tts_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||||
"""Initialize Wyoming."""
|
"""Create a config entry."""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain="wyoming",
|
||||||
|
data={
|
||||||
|
"host": "1.2.3.4",
|
||||||
|
"port": 1234,
|
||||||
|
},
|
||||||
|
title="Test TTS",
|
||||||
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry):
|
||||||
|
"""Initialize Wyoming STT."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=STT_INFO,
|
return_value=STT_INFO,
|
||||||
):
|
):
|
||||||
await hass.config_entries.async_setup(config_entry.entry_id)
|
await hass.config_entries.async_setup(stt_config_entry.entry_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
|
||||||
|
"""Initialize Wyoming TTS."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=TTS_INFO,
|
||||||
|
):
|
||||||
|
await hass.config_entries.async_setup(tts_config_entry.entry_id)
|
||||||
|
11
tests/components/wyoming/snapshots/test_data.ambr
Normal file
11
tests/components/wyoming/snapshots/test_data.ambr
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_load_info
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'describe',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
23
tests/components/wyoming/snapshots/test_tts.ambr
Normal file
23
tests/components/wyoming/snapshots/test_tts.ambr
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_get_tts_audio
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Hello world',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_get_tts_audio_raw
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'text': 'Hello world',
|
||||||
|
}),
|
||||||
|
'payload': None,
|
||||||
|
'type': 'synthesize',
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
@ -10,7 +10,7 @@ from homeassistant.components.wyoming.const import DOMAIN
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
|
|
||||||
from . import EMPTY_INFO, STT_INFO
|
from . import EMPTY_INFO, STT_INFO, TTS_INFO
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
@ -26,7 +26,7 @@ ADDON_DISCOVERY = HassioServiceInfo(
|
|||||||
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
|
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
|
||||||
|
|
||||||
|
|
||||||
async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
async def test_form_stt(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
||||||
"""Test we get the form."""
|
"""Test we get the form."""
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
@ -56,6 +56,36 @@ async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
|||||||
assert len(mock_setup_entry.mock_calls) == 1
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_form_tts(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
|
||||||
|
"""Test we get the form."""
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["errors"] is None
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=TTS_INFO,
|
||||||
|
):
|
||||||
|
result2 = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
"host": "1.1.1.1",
|
||||||
|
"port": 1234,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert result2["type"] == FlowResultType.CREATE_ENTRY
|
||||||
|
assert result2["title"] == "Test TTS"
|
||||||
|
assert result2["data"] == {
|
||||||
|
"host": "1.1.1.1",
|
||||||
|
"port": 1234,
|
||||||
|
}
|
||||||
|
assert len(mock_setup_entry.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_form_cannot_connect(hass: HomeAssistant) -> None:
|
async def test_form_cannot_connect(hass: HomeAssistant) -> None:
|
||||||
"""Test we handle cannot connect error."""
|
"""Test we handle cannot connect error."""
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
40
tests/components/wyoming/test_data.py
Normal file
40
tests/components/wyoming/test_data.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
"""Test tts."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from homeassistant.components.wyoming.data import load_wyoming_info
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from . import STT_INFO, MockAsyncTcpClient
|
||||||
|
|
||||||
|
|
||||||
|
async def test_load_info(hass: HomeAssistant, snapshot) -> None:
|
||||||
|
"""Test loading info."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient([STT_INFO.event()]),
|
||||||
|
) as mock_client:
|
||||||
|
info = await load_wyoming_info("localhost", 1234)
|
||||||
|
|
||||||
|
assert info == STT_INFO
|
||||||
|
assert mock_client.written == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_load_info_oserror(hass: HomeAssistant) -> None:
|
||||||
|
"""Test loading info and error raising."""
|
||||||
|
mock_client = MockAsyncTcpClient([STT_INFO.event()])
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
|
mock_client,
|
||||||
|
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
|
||||||
|
info = await load_wyoming_info(
|
||||||
|
"localhost",
|
||||||
|
1234,
|
||||||
|
retries=0,
|
||||||
|
retry_wait=0,
|
||||||
|
timeout=0.001,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert info is None
|
@ -5,17 +5,19 @@ from homeassistant.config_entries import ConfigEntry
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
async def test_cannot_connect(hass: HomeAssistant, config_entry: ConfigEntry) -> None:
|
async def test_cannot_connect(
|
||||||
|
hass: HomeAssistant, stt_config_entry: ConfigEntry
|
||||||
|
) -> None:
|
||||||
"""Test we handle cannot connect error."""
|
"""Test we handle cannot connect error."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
return_value=None,
|
return_value=None,
|
||||||
):
|
):
|
||||||
assert not await hass.config_entries.async_setup(config_entry.entry_id)
|
assert not await hass.config_entries.async_setup(stt_config_entry.entry_id)
|
||||||
|
|
||||||
|
|
||||||
async def test_unload(
|
async def test_unload(
|
||||||
hass: HomeAssistant, config_entry: ConfigEntry, init_wyoming_stt
|
hass: HomeAssistant, stt_config_entry: ConfigEntry, init_wyoming_stt
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test unload."""
|
"""Test unload."""
|
||||||
assert await hass.config_entries.async_unload(config_entry.entry_id)
|
assert await hass.config_entries.async_unload(stt_config_entry.entry_id)
|
||||||
|
@ -3,50 +3,21 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from wyoming.event import Event
|
from wyoming.asr import Transcript
|
||||||
|
|
||||||
from homeassistant.components import stt
|
from homeassistant.components import stt
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from . import MockAsyncTcpClient
|
||||||
class MockAsyncTcpClient:
|
|
||||||
"""Mock AsyncTcpClient."""
|
|
||||||
|
|
||||||
def __init__(self, responses) -> None:
|
|
||||||
"""Initialize."""
|
|
||||||
self.host = None
|
|
||||||
self.port = None
|
|
||||||
self.written = []
|
|
||||||
self.responses = responses
|
|
||||||
|
|
||||||
async def write_event(self, event):
|
|
||||||
"""Send."""
|
|
||||||
self.written.append(event)
|
|
||||||
|
|
||||||
async def read_event(self):
|
|
||||||
"""Receive."""
|
|
||||||
return self.responses.pop(0)
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
|
||||||
"""Enter."""
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
|
||||||
"""Exit."""
|
|
||||||
|
|
||||||
def __call__(self, host, port):
|
|
||||||
"""Call."""
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||||
"""Test streaming audio."""
|
"""Test supported properties."""
|
||||||
state = hass.states.get("stt.wyoming")
|
state = hass.states.get("stt.wyoming")
|
||||||
assert state is not None
|
assert state is not None
|
||||||
|
|
||||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||||
|
assert entity is not None
|
||||||
|
|
||||||
assert entity.supported_languages == ["en-US"]
|
assert entity.supported_languages == ["en-US"]
|
||||||
assert entity.supported_formats == [stt.AudioFormats.WAV]
|
assert entity.supported_formats == [stt.AudioFormats.WAV]
|
||||||
@ -59,6 +30,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
|
|||||||
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
|
async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
|
||||||
"""Test streaming audio."""
|
"""Test streaming audio."""
|
||||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||||
|
assert entity is not None
|
||||||
|
|
||||||
async def audio_stream():
|
async def audio_stream():
|
||||||
yield "chunk1"
|
yield "chunk1"
|
||||||
@ -66,7 +38,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot)
|
|||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||||
MockAsyncTcpClient([Event(type="transcript", data={"text": "Hello world"})]),
|
MockAsyncTcpClient([Transcript(text="Hello world").event()]),
|
||||||
) as mock_client:
|
) as mock_client:
|
||||||
result = await entity.async_process_audio_stream(None, audio_stream())
|
result = await entity.async_process_audio_stream(None, audio_stream())
|
||||||
|
|
||||||
@ -80,6 +52,7 @@ async def test_streaming_audio_connection_lost(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test streaming audio and losing connection."""
|
"""Test streaming audio and losing connection."""
|
||||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||||
|
assert entity is not None
|
||||||
|
|
||||||
async def audio_stream():
|
async def audio_stream():
|
||||||
yield "chunk1"
|
yield "chunk1"
|
||||||
@ -97,13 +70,12 @@ async def test_streaming_audio_connection_lost(
|
|||||||
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
|
async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
|
||||||
"""Test streaming audio and error raising."""
|
"""Test streaming audio and error raising."""
|
||||||
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming")
|
||||||
|
assert entity is not None
|
||||||
|
|
||||||
async def audio_stream():
|
async def audio_stream():
|
||||||
yield "chunk1"
|
yield "chunk1"
|
||||||
|
|
||||||
mock_client = MockAsyncTcpClient(
|
mock_client = MockAsyncTcpClient([Transcript(text="Hello world").event()])
|
||||||
[Event(type="transcript", data={"text": "Hello world"})]
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
"homeassistant.components.wyoming.stt.AsyncTcpClient",
|
||||||
|
143
tests/components/wyoming/test_tts.py
Normal file
143
tests/components/wyoming/test_tts.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
"""Test tts."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import io
|
||||||
|
from unittest.mock import patch
|
||||||
|
import wave
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from wyoming.audio import AudioChunk, AudioStop
|
||||||
|
|
||||||
|
from homeassistant.components import tts
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
from homeassistant.helpers.entity_component import DATA_INSTANCES
|
||||||
|
|
||||||
|
from . import MockAsyncTcpClient
|
||||||
|
|
||||||
|
from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import
|
||||||
|
init_cache_dir_side_effect,
|
||||||
|
mock_get_cache_files,
|
||||||
|
mock_init_cache_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
|
||||||
|
"""Test supported properties."""
|
||||||
|
state = hass.states.get("tts.test_tts")
|
||||||
|
assert state is not None
|
||||||
|
|
||||||
|
entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts")
|
||||||
|
assert entity is not None
|
||||||
|
|
||||||
|
assert entity.supported_languages == ["en-US"]
|
||||||
|
assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE]
|
||||||
|
voices = entity.async_get_supported_voices("en-US")
|
||||||
|
assert len(voices) == 1
|
||||||
|
assert voices[0].name == "Test Voice"
|
||||||
|
assert voices[0].voice_id == "Test Voice"
|
||||||
|
assert not entity.async_get_supported_voices("de-DE")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None:
|
||||||
|
"""Test get audio."""
|
||||||
|
audio = bytes(100)
|
||||||
|
audio_events = [
|
||||||
|
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||||
|
AudioStop().event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient(audio_events),
|
||||||
|
) as mock_client:
|
||||||
|
extension, data = await tts.async_get_media_source_audio(
|
||||||
|
hass,
|
||||||
|
tts.generate_media_source_id(
|
||||||
|
hass, "Hello world", "tts.test_tts", hass.config.language
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert extension == "wav"
|
||||||
|
assert data is not None
|
||||||
|
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||||
|
assert wav_file.getframerate() == 16000
|
||||||
|
assert wav_file.getsampwidth() == 2
|
||||||
|
assert wav_file.getnchannels() == 1
|
||||||
|
assert wav_file.readframes(wav_file.getnframes()) == audio
|
||||||
|
|
||||||
|
assert mock_client.written == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_tts_audio_raw(
|
||||||
|
hass: HomeAssistant, init_wyoming_tts, snapshot
|
||||||
|
) -> None:
|
||||||
|
"""Test get raw audio."""
|
||||||
|
audio = bytes(100)
|
||||||
|
audio_events = [
|
||||||
|
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||||
|
AudioStop().event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient(audio_events),
|
||||||
|
) as mock_client:
|
||||||
|
extension, data = await tts.async_get_media_source_audio(
|
||||||
|
hass,
|
||||||
|
tts.generate_media_source_id(
|
||||||
|
hass,
|
||||||
|
"Hello world",
|
||||||
|
"tts.test_tts",
|
||||||
|
hass.config.language,
|
||||||
|
options={tts.ATTR_AUDIO_OUTPUT: "raw"},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert extension == "raw"
|
||||||
|
assert data == audio
|
||||||
|
assert mock_client.written == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_tts_audio_connection_lost(
|
||||||
|
hass: HomeAssistant, init_wyoming_tts
|
||||||
|
) -> None:
|
||||||
|
"""Test streaming audio and losing connection."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient([None]),
|
||||||
|
), pytest.raises(HomeAssistantError):
|
||||||
|
await tts.async_get_media_source_audio(
|
||||||
|
hass,
|
||||||
|
tts.generate_media_source_id(
|
||||||
|
hass, "Hello world", "tts.test_tts", hass.config.language
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_tts_audio_audio_oserror(
|
||||||
|
hass: HomeAssistant, init_wyoming_tts
|
||||||
|
) -> None:
|
||||||
|
"""Test get audio and error raising."""
|
||||||
|
audio = bytes(100)
|
||||||
|
audio_events = [
|
||||||
|
AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(),
|
||||||
|
AudioStop().event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_client = MockAsyncTcpClient(audio_events)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.tts.AsyncTcpClient",
|
||||||
|
mock_client,
|
||||||
|
), patch.object(
|
||||||
|
mock_client, "read_event", side_effect=OSError("Boom!")
|
||||||
|
), pytest.raises(
|
||||||
|
HomeAssistantError
|
||||||
|
):
|
||||||
|
await tts.async_get_media_source_audio(
|
||||||
|
hass,
|
||||||
|
tts.generate_media_source_id(
|
||||||
|
hass, "Hello world", "tts.test_tts", hass.config.language
|
||||||
|
),
|
||||||
|
)
|
Loading…
x
Reference in New Issue
Block a user