diff --git a/homeassistant/components/wyoming/config_flow.py b/homeassistant/components/wyoming/config_flow.py index ee7e1d574ac..a8facbf1c30 100644 --- a/homeassistant/components/wyoming/config_flow.py +++ b/homeassistant/components/wyoming/config_flow.py @@ -52,10 +52,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): # ASR = automated speech recognition (STT) asr_installed = [asr for asr in service.info.asr if asr.installed] - if not asr_installed: - return self.async_abort(reason="no_services") + tts_installed = [tts for tts in service.info.tts if tts.installed] - name = asr_installed[0].name + if asr_installed: + name = asr_installed[0].name + elif tts_installed: + name = tts_installed[0].name + else: + return self.async_abort(reason="no_services") return self.async_create_entry(title=name, data=user_input) diff --git a/homeassistant/components/wyoming/data.py b/homeassistant/components/wyoming/data.py index f5f869d8e68..3ef93810b6e 100644 --- a/homeassistant/components/wyoming/data.py +++ b/homeassistant/components/wyoming/data.py @@ -25,8 +25,10 @@ class WyomingService: self.port = port self.info = info platforms = [] - if info.asr: + if any(asr.installed for asr in info.asr): platforms.append(Platform.STT) + if any(tts.installed for tts in info.tts): + platforms.append(Platform.TTS) self.platforms = platforms @classmethod @@ -39,14 +41,20 @@ class WyomingService: return cls(host, port, info) -async def load_wyoming_info(host: str, port: int) -> Info | None: +async def load_wyoming_info( + host: str, + port: int, + retries: int = _INFO_RETRIES, + retry_wait: float = _INFO_RETRY_WAIT, + timeout: float = _INFO_TIMEOUT, +) -> Info | None: """Load info from Wyoming server.""" wyoming_info: Info | None = None - for _ in range(_INFO_RETRIES): + for _ in range(retries + 1): try: async with AsyncTcpClient(host, port) as client: - with async_timeout.timeout(_INFO_TIMEOUT): + with async_timeout.timeout(timeout): # Describe -> Info await client.write_event(Describe().event()) while True: @@ -58,9 +66,12 @@ async def load_wyoming_info(host: str, port: int) -> Info | None: if Info.is_type(event.type): wyoming_info = Info.from_event(event) - break + break # while + + if wyoming_info is not None: + break # for except (asyncio.TimeoutError, OSError, WyomingError): # Sleep and try again - await asyncio.sleep(_INFO_RETRY_WAIT) + await asyncio.sleep(retry_wait) return wyoming_info diff --git a/homeassistant/components/wyoming/tts.py b/homeassistant/components/wyoming/tts.py new file mode 100644 index 00000000000..8a6687ea888 --- /dev/null +++ b/homeassistant/components/wyoming/tts.py @@ -0,0 +1,161 @@ +"""Support for Wyoming text to speech services.""" +from collections import defaultdict +import io +import logging +import wave + +from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop +from wyoming.client import AsyncTcpClient +from wyoming.tts import Synthesize + +from homeassistant.components import tts +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .data import WyomingService +from .error import WyomingError + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Wyoming speech to text.""" + service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] + async_add_entities( + [ + WyomingTtsProvider(config_entry, service), + ] + ) + + +class WyomingTtsProvider(tts.TextToSpeechEntity): + """Wyoming text to speech provider.""" + + def __init__( + self, + config_entry: ConfigEntry, + service: WyomingService, + ) -> None: + """Set up provider.""" + self.service = service + self._tts_service = next(tts for tts in service.info.tts if tts.installed) + + voice_languages: set[str] = set() + self._voices: dict[str, list[tts.Voice]] = defaultdict(list) + for voice in self._tts_service.voices: + if not voice.installed: + continue + + voice_languages.update(voice.languages) + for language in voice.languages: + self._voices[language].append( + tts.Voice( + voice_id=voice.name, + name=voice.name, + ) + ) + + self._supported_languages: list[str] = list(voice_languages) + + self._attr_name = self._tts_service.name + self._attr_unique_id = f"{config_entry.entry_id}-tts" + + @property + def name(self) -> str | None: + """Return the name of the provider entity.""" + # Only one entity is allowed per platform for now. + return self._tts_service.name + + @property + def default_language(self): + """Return default language.""" + if not self._supported_languages: + return None + + return self._supported_languages[0] + + @property + def supported_languages(self): + """Return list of supported languages.""" + return self._supported_languages + + @property + def supported_options(self): + """Return list of supported options like voice, emotion.""" + return [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE] + + @property + def default_options(self): + """Return a dict include default options.""" + return {tts.ATTR_AUDIO_OUTPUT: "wav"} + + @callback + def async_get_supported_voices(self, language: str) -> list[tts.Voice] | None: + """Return a list of supported voices for a language.""" + return self._voices.get(language) + + async def async_get_tts_audio(self, message, language, options=None): + """Load TTS from UNIX socket.""" + try: + async with AsyncTcpClient(self.service.host, self.service.port) as client: + await client.write_event(Synthesize(message).event()) + + with io.BytesIO() as wav_io: + wav_writer: wave.Wave_write | None = None + while True: + event = await client.read_event() + if event is None: + _LOGGER.debug("Connection lost") + return (None, None) + + if AudioStop.is_type(event.type): + break + + if AudioChunk.is_type(event.type): + chunk = AudioChunk.from_event(event) + if wav_writer is None: + wav_writer = wave.open(wav_io, "wb") + wav_writer.setframerate(chunk.rate) + wav_writer.setsampwidth(chunk.width) + wav_writer.setnchannels(chunk.channels) + + wav_writer.writeframes(chunk.audio) + + if wav_writer is not None: + wav_writer.close() + + data = wav_io.getvalue() + + except (OSError, WyomingError): + return (None, None) + + if (options is None) or (options[tts.ATTR_AUDIO_OUTPUT] == "wav"): + return ("wav", data) + + # Raw output (convert to 16Khz, 16-bit mono) + with io.BytesIO(data) as wav_io: + wav_reader: wave.Wave_read = wave.open(wav_io, "rb") + raw_data = ( + AudioChunkConverter( + rate=16000, + width=2, + channels=1, + ) + .convert( + AudioChunk( + audio=wav_reader.readframes(wav_reader.getnframes()), + rate=wav_reader.getframerate(), + width=wav_reader.getsampwidth(), + channels=wav_reader.getnchannels(), + ) + ) + .audio + ) + + return ("raw", raw_data) diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 5df845bb63a..d48b908f26b 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -1,5 +1,5 @@ """Tests for the Wyoming integration.""" -from wyoming.info import AsrModel, AsrProgram, Attribution, Info +from wyoming.info import AsrModel, AsrProgram, Attribution, Info, TtsProgram, TtsVoice TEST_ATTR = Attribution(name="Test", url="http://www.test.com") STT_INFO = Info( @@ -19,4 +19,53 @@ STT_INFO = Info( ) ] ) +TTS_INFO = Info( + tts=[ + TtsProgram( + name="Test TTS", + installed=True, + attribution=TEST_ATTR, + voices=[ + TtsVoice( + name="Test Voice", + installed=True, + attribution=TEST_ATTR, + languages=["en-US"], + ) + ], + ) + ] +) EMPTY_INFO = Info() + + +class MockAsyncTcpClient: + """Mock AsyncTcpClient.""" + + def __init__(self, responses) -> None: + """Initialize.""" + self.host = None + self.port = None + self.written = [] + self.responses = responses + + async def write_event(self, event): + """Send.""" + self.written.append(event) + + async def read_event(self): + """Receive.""" + return self.responses.pop(0) + + async def __aenter__(self): + """Enter.""" + return self + + async def __aexit__(self, exc_type, exc, tb): + """Exit.""" + + def __call__(self, host, port): + """Call.""" + self.host = host + self.port = port + return self diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index a3c83901453..0dd9041a0d5 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -7,7 +7,7 @@ import pytest from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -from . import STT_INFO +from . import STT_INFO, TTS_INFO from tests.common import MockConfigEntry @@ -22,7 +22,7 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]: @pytest.fixture -def config_entry(hass: HomeAssistant) -> ConfigEntry: +def stt_config_entry(hass: HomeAssistant) -> ConfigEntry: """Create a config entry.""" entry = MockConfigEntry( domain="wyoming", @@ -37,10 +37,35 @@ def config_entry(hass: HomeAssistant) -> ConfigEntry: @pytest.fixture -async def init_wyoming_stt(hass: HomeAssistant, config_entry: ConfigEntry): - """Initialize Wyoming.""" +def tts_config_entry(hass: HomeAssistant) -> ConfigEntry: + """Create a config entry.""" + entry = MockConfigEntry( + domain="wyoming", + data={ + "host": "1.2.3.4", + "port": 1234, + }, + title="Test TTS", + ) + entry.add_to_hass(hass) + return entry + + +@pytest.fixture +async def init_wyoming_stt(hass: HomeAssistant, stt_config_entry: ConfigEntry): + """Initialize Wyoming STT.""" with patch( "homeassistant.components.wyoming.data.load_wyoming_info", return_value=STT_INFO, ): - await hass.config_entries.async_setup(config_entry.entry_id) + await hass.config_entries.async_setup(stt_config_entry.entry_id) + + +@pytest.fixture +async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry): + """Initialize Wyoming TTS.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=TTS_INFO, + ): + await hass.config_entries.async_setup(tts_config_entry.entry_id) diff --git a/tests/components/wyoming/snapshots/test_data.ambr b/tests/components/wyoming/snapshots/test_data.ambr new file mode 100644 index 00000000000..c47e40a0dc4 --- /dev/null +++ b/tests/components/wyoming/snapshots/test_data.ambr @@ -0,0 +1,11 @@ +# serializer version: 1 +# name: test_load_info + list([ + dict({ + 'data': dict({ + }), + 'payload': None, + 'type': 'describe', + }), + ]) +# --- diff --git a/tests/components/wyoming/snapshots/test_tts.ambr b/tests/components/wyoming/snapshots/test_tts.ambr new file mode 100644 index 00000000000..eb0b33c3276 --- /dev/null +++ b/tests/components/wyoming/snapshots/test_tts.ambr @@ -0,0 +1,23 @@ +# serializer version: 1 +# name: test_get_tts_audio + list([ + dict({ + 'data': dict({ + 'text': 'Hello world', + }), + 'payload': None, + 'type': 'synthesize', + }), + ]) +# --- +# name: test_get_tts_audio_raw + list([ + dict({ + 'data': dict({ + 'text': 'Hello world', + }), + 'payload': None, + 'type': 'synthesize', + }), + ]) +# --- diff --git a/tests/components/wyoming/test_config_flow.py b/tests/components/wyoming/test_config_flow.py index 9f9b123a411..54a5a2a8679 100644 --- a/tests/components/wyoming/test_config_flow.py +++ b/tests/components/wyoming/test_config_flow.py @@ -10,7 +10,7 @@ from homeassistant.components.wyoming.const import DOMAIN from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType -from . import EMPTY_INFO, STT_INFO +from . import EMPTY_INFO, STT_INFO, TTS_INFO from tests.common import MockConfigEntry @@ -26,7 +26,7 @@ ADDON_DISCOVERY = HassioServiceInfo( pytestmark = pytest.mark.usefixtures("mock_setup_entry") -async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None: +async def test_form_stt(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None: """Test we get the form.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} @@ -56,6 +56,36 @@ async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None: assert len(mock_setup_entry.mock_calls) == 1 +async def test_form_tts(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None: + """Test we get the form.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert result["errors"] is None + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=TTS_INFO, + ): + result2 = await hass.config_entries.flow.async_configure( + result["flow_id"], + { + "host": "1.1.1.1", + "port": 1234, + }, + ) + await hass.async_block_till_done() + + assert result2["type"] == FlowResultType.CREATE_ENTRY + assert result2["title"] == "Test TTS" + assert result2["data"] == { + "host": "1.1.1.1", + "port": 1234, + } + assert len(mock_setup_entry.mock_calls) == 1 + + async def test_form_cannot_connect(hass: HomeAssistant) -> None: """Test we handle cannot connect error.""" result = await hass.config_entries.flow.async_init( diff --git a/tests/components/wyoming/test_data.py b/tests/components/wyoming/test_data.py new file mode 100644 index 00000000000..0cb878c39c1 --- /dev/null +++ b/tests/components/wyoming/test_data.py @@ -0,0 +1,40 @@ +"""Test tts.""" +from __future__ import annotations + +from unittest.mock import patch + +from homeassistant.components.wyoming.data import load_wyoming_info +from homeassistant.core import HomeAssistant + +from . import STT_INFO, MockAsyncTcpClient + + +async def test_load_info(hass: HomeAssistant, snapshot) -> None: + """Test loading info.""" + with patch( + "homeassistant.components.wyoming.data.AsyncTcpClient", + MockAsyncTcpClient([STT_INFO.event()]), + ) as mock_client: + info = await load_wyoming_info("localhost", 1234) + + assert info == STT_INFO + assert mock_client.written == snapshot + + +async def test_load_info_oserror(hass: HomeAssistant) -> None: + """Test loading info and error raising.""" + mock_client = MockAsyncTcpClient([STT_INFO.event()]) + + with patch( + "homeassistant.components.wyoming.data.AsyncTcpClient", + mock_client, + ), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): + info = await load_wyoming_info( + "localhost", + 1234, + retries=0, + retry_wait=0, + timeout=0.001, + ) + + assert info is None diff --git a/tests/components/wyoming/test_init.py b/tests/components/wyoming/test_init.py index 1a8b89d9b5e..85539f5a164 100644 --- a/tests/components/wyoming/test_init.py +++ b/tests/components/wyoming/test_init.py @@ -5,17 +5,19 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant -async def test_cannot_connect(hass: HomeAssistant, config_entry: ConfigEntry) -> None: +async def test_cannot_connect( + hass: HomeAssistant, stt_config_entry: ConfigEntry +) -> None: """Test we handle cannot connect error.""" with patch( "homeassistant.components.wyoming.data.load_wyoming_info", return_value=None, ): - assert not await hass.config_entries.async_setup(config_entry.entry_id) + assert not await hass.config_entries.async_setup(stt_config_entry.entry_id) async def test_unload( - hass: HomeAssistant, config_entry: ConfigEntry, init_wyoming_stt + hass: HomeAssistant, stt_config_entry: ConfigEntry, init_wyoming_stt ) -> None: """Test unload.""" - assert await hass.config_entries.async_unload(config_entry.entry_id) + assert await hass.config_entries.async_unload(stt_config_entry.entry_id) diff --git a/tests/components/wyoming/test_stt.py b/tests/components/wyoming/test_stt.py index 1f73426e9f9..6c9e75ffa18 100644 --- a/tests/components/wyoming/test_stt.py +++ b/tests/components/wyoming/test_stt.py @@ -3,50 +3,21 @@ from __future__ import annotations from unittest.mock import patch -from wyoming.event import Event +from wyoming.asr import Transcript from homeassistant.components import stt from homeassistant.core import HomeAssistant - -class MockAsyncTcpClient: - """Mock AsyncTcpClient.""" - - def __init__(self, responses) -> None: - """Initialize.""" - self.host = None - self.port = None - self.written = [] - self.responses = responses - - async def write_event(self, event): - """Send.""" - self.written.append(event) - - async def read_event(self): - """Receive.""" - return self.responses.pop(0) - - async def __aenter__(self): - """Enter.""" - return self - - async def __aexit__(self, exc_type, exc, tb): - """Exit.""" - - def __call__(self, host, port): - """Call.""" - self.host = host - self.port = port - return self +from . import MockAsyncTcpClient async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None: - """Test streaming audio.""" + """Test supported properties.""" state = hass.states.get("stt.wyoming") assert state is not None entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + assert entity is not None assert entity.supported_languages == ["en-US"] assert entity.supported_formats == [stt.AudioFormats.WAV] @@ -59,6 +30,7 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None: async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None: """Test streaming audio.""" entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + assert entity is not None async def audio_stream(): yield "chunk1" @@ -66,7 +38,7 @@ async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) with patch( "homeassistant.components.wyoming.stt.AsyncTcpClient", - MockAsyncTcpClient([Event(type="transcript", data={"text": "Hello world"})]), + MockAsyncTcpClient([Transcript(text="Hello world").event()]), ) as mock_client: result = await entity.async_process_audio_stream(None, audio_stream()) @@ -80,6 +52,7 @@ async def test_streaming_audio_connection_lost( ) -> None: """Test streaming audio and losing connection.""" entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + assert entity is not None async def audio_stream(): yield "chunk1" @@ -97,13 +70,12 @@ async def test_streaming_audio_connection_lost( async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None: """Test streaming audio and error raising.""" entity = stt.async_get_speech_to_text_entity(hass, "stt.wyoming") + assert entity is not None async def audio_stream(): yield "chunk1" - mock_client = MockAsyncTcpClient( - [Event(type="transcript", data={"text": "Hello world"})] - ) + mock_client = MockAsyncTcpClient([Transcript(text="Hello world").event()]) with patch( "homeassistant.components.wyoming.stt.AsyncTcpClient", diff --git a/tests/components/wyoming/test_tts.py b/tests/components/wyoming/test_tts.py new file mode 100644 index 00000000000..69580e6456c --- /dev/null +++ b/tests/components/wyoming/test_tts.py @@ -0,0 +1,143 @@ +"""Test tts.""" +from __future__ import annotations + +import io +from unittest.mock import patch +import wave + +import pytest +from wyoming.audio import AudioChunk, AudioStop + +from homeassistant.components import tts +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError +from homeassistant.helpers.entity_component import DATA_INSTANCES + +from . import MockAsyncTcpClient + +from tests.components.tts.conftest import ( # noqa: F401, pylint: disable=unused-import + init_cache_dir_side_effect, + mock_get_cache_files, + mock_init_cache_dir, +) + + +async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None: + """Test supported properties.""" + state = hass.states.get("tts.test_tts") + assert state is not None + + entity = hass.data[DATA_INSTANCES]["tts"].get_entity("tts.test_tts") + assert entity is not None + + assert entity.supported_languages == ["en-US"] + assert entity.supported_options == [tts.ATTR_AUDIO_OUTPUT, tts.ATTR_VOICE] + voices = entity.async_get_supported_voices("en-US") + assert len(voices) == 1 + assert voices[0].name == "Test Voice" + assert voices[0].voice_id == "Test Voice" + assert not entity.async_get_supported_voices("de-DE") + + +async def test_get_tts_audio(hass: HomeAssistant, init_wyoming_tts, snapshot) -> None: + """Test get audio.""" + audio = bytes(100) + audio_events = [ + AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), + AudioStop().event(), + ] + + with patch( + "homeassistant.components.wyoming.tts.AsyncTcpClient", + MockAsyncTcpClient(audio_events), + ) as mock_client: + extension, data = await tts.async_get_media_source_audio( + hass, + tts.generate_media_source_id( + hass, "Hello world", "tts.test_tts", hass.config.language + ), + ) + + assert extension == "wav" + assert data is not None + with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: + assert wav_file.getframerate() == 16000 + assert wav_file.getsampwidth() == 2 + assert wav_file.getnchannels() == 1 + assert wav_file.readframes(wav_file.getnframes()) == audio + + assert mock_client.written == snapshot + + +async def test_get_tts_audio_raw( + hass: HomeAssistant, init_wyoming_tts, snapshot +) -> None: + """Test get raw audio.""" + audio = bytes(100) + audio_events = [ + AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), + AudioStop().event(), + ] + + with patch( + "homeassistant.components.wyoming.tts.AsyncTcpClient", + MockAsyncTcpClient(audio_events), + ) as mock_client: + extension, data = await tts.async_get_media_source_audio( + hass, + tts.generate_media_source_id( + hass, + "Hello world", + "tts.test_tts", + hass.config.language, + options={tts.ATTR_AUDIO_OUTPUT: "raw"}, + ), + ) + + assert extension == "raw" + assert data == audio + assert mock_client.written == snapshot + + +async def test_get_tts_audio_connection_lost( + hass: HomeAssistant, init_wyoming_tts +) -> None: + """Test streaming audio and losing connection.""" + with patch( + "homeassistant.components.wyoming.tts.AsyncTcpClient", + MockAsyncTcpClient([None]), + ), pytest.raises(HomeAssistantError): + await tts.async_get_media_source_audio( + hass, + tts.generate_media_source_id( + hass, "Hello world", "tts.test_tts", hass.config.language + ), + ) + + +async def test_get_tts_audio_audio_oserror( + hass: HomeAssistant, init_wyoming_tts +) -> None: + """Test get audio and error raising.""" + audio = bytes(100) + audio_events = [ + AudioChunk(audio=audio, rate=16000, width=2, channels=1).event(), + AudioStop().event(), + ] + + mock_client = MockAsyncTcpClient(audio_events) + + with patch( + "homeassistant.components.wyoming.tts.AsyncTcpClient", + mock_client, + ), patch.object( + mock_client, "read_event", side_effect=OSError("Boom!") + ), pytest.raises( + HomeAssistantError + ): + await tts.async_get_media_source_audio( + hass, + tts.generate_media_source_id( + hass, "Hello world", "tts.test_tts", hass.config.language + ), + )