diff --git a/homeassistant/components/wyoming/data.py b/homeassistant/components/wyoming/data.py index adcb472d5e0..e333a740741 100644 --- a/homeassistant/components/wyoming/data.py +++ b/homeassistant/components/wyoming/data.py @@ -1,10 +1,11 @@ """Base class for Wyoming providers.""" + from __future__ import annotations import asyncio from wyoming.client import AsyncTcpClient -from wyoming.info import Describe, Info, Satellite +from wyoming.info import Describe, Info from homeassistant.const import Platform @@ -23,14 +24,19 @@ class WyomingService: self.host = host self.port = port self.info = info - platforms = [] + self.platforms = [] + + if (self.info.satellite is not None) and self.info.satellite.installed: + # Don't load platforms for satellite services, such as local wake + # word detection. + return + if any(asr.installed for asr in info.asr): - platforms.append(Platform.STT) + self.platforms.append(Platform.STT) if any(tts.installed for tts in info.tts): - platforms.append(Platform.TTS) + self.platforms.append(Platform.TTS) if any(wake.installed for wake in info.wake): - platforms.append(Platform.WAKE_WORD) - self.platforms = platforms + self.platforms.append(Platform.WAKE_WORD) def has_services(self) -> bool: """Return True if services are installed that Home Assistant can use.""" @@ -43,6 +49,12 @@ class WyomingService: def get_name(self) -> str | None: """Return name of first installed usable service.""" + + # Wyoming satellite + # Must be checked first because satellites may contain wake services, etc. + if (self.info.satellite is not None) and self.info.satellite.installed: + return self.info.satellite.name + # ASR = automated speech recognition (speech-to-text) asr_installed = [asr for asr in self.info.asr if asr.installed] if asr_installed: @@ -58,15 +70,6 @@ class WyomingService: if wake_installed: return wake_installed[0].name - # satellite - satellite_installed: Satellite | None = None - - if (self.info.satellite is not None) and self.info.satellite.installed: - satellite_installed = self.info.satellite - - if satellite_installed: - return satellite_installed.name - return None @classmethod diff --git a/tests/components/wyoming/test_data.py b/tests/components/wyoming/test_data.py index b7de9dbfdc1..282326b2ce0 100644 --- a/tests/components/wyoming/test_data.py +++ b/tests/components/wyoming/test_data.py @@ -1,9 +1,11 @@ """Test tts.""" + from __future__ import annotations from unittest.mock import patch from syrupy.assertion import SnapshotAssertion +from wyoming.info import Info from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info from homeassistant.core import HomeAssistant @@ -27,10 +29,13 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None: """Test loading info and error raising.""" mock_client = MockAsyncTcpClient([STT_INFO.event()]) - with patch( - "homeassistant.components.wyoming.data.AsyncTcpClient", - mock_client, - ), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): + with ( + patch( + "homeassistant.components.wyoming.data.AsyncTcpClient", + mock_client, + ), + patch.object(mock_client, "read_event", side_effect=OSError("Boom!")), + ): info = await load_wyoming_info( "localhost", 1234, @@ -75,3 +80,21 @@ async def test_service_name(hass: HomeAssistant) -> None: service = await WyomingService.create("localhost", 1234) assert service is not None assert service.get_name() == SATELLITE_INFO.satellite.name + + +async def test_satellite_with_wake_word(hass: HomeAssistant) -> None: + """Test that wake word info with satellite doesn't overwrite the service name.""" + # Info for local wake word detection + satellite_info = Info( + satellite=SATELLITE_INFO.satellite, + wake=WAKE_WORD_INFO.wake, + ) + + with patch( + "homeassistant.components.wyoming.data.AsyncTcpClient", + MockAsyncTcpClient([satellite_info.event()]), + ): + service = await WyomingService.create("localhost", 1234) + assert service is not None + assert service.get_name() == satellite_info.satellite.name + assert not service.platforms