Use correct service name with Wyoming satellite + local wake word detection (#111870)

* Use correct service name with satellite + local wake word detection

* Don't load platforms for satellite services

* Update homeassistant/components/wyoming/data.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Fix ruff error

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2024-02-29 12:09:38 -06:00 committed by GitHub
parent 66b17a8e0d
commit f0deae319e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 19 deletions

View File

@ -1,10 +1,11 @@
"""Base class for Wyoming providers.""" """Base class for Wyoming providers."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
from wyoming.info import Describe, Info, Satellite from wyoming.info import Describe, Info
from homeassistant.const import Platform from homeassistant.const import Platform
@ -23,14 +24,19 @@ class WyomingService:
self.host = host self.host = host
self.port = port self.port = port
self.info = info self.info = info
platforms = [] self.platforms = []
if (self.info.satellite is not None) and self.info.satellite.installed:
# Don't load platforms for satellite services, such as local wake
# word detection.
return
if any(asr.installed for asr in info.asr): if any(asr.installed for asr in info.asr):
platforms.append(Platform.STT) self.platforms.append(Platform.STT)
if any(tts.installed for tts in info.tts): if any(tts.installed for tts in info.tts):
platforms.append(Platform.TTS) self.platforms.append(Platform.TTS)
if any(wake.installed for wake in info.wake): if any(wake.installed for wake in info.wake):
platforms.append(Platform.WAKE_WORD) self.platforms.append(Platform.WAKE_WORD)
self.platforms = platforms
def has_services(self) -> bool: def has_services(self) -> bool:
"""Return True if services are installed that Home Assistant can use.""" """Return True if services are installed that Home Assistant can use."""
@ -43,6 +49,12 @@ class WyomingService:
def get_name(self) -> str | None: def get_name(self) -> str | None:
"""Return name of first installed usable service.""" """Return name of first installed usable service."""
# Wyoming satellite
# Must be checked first because satellites may contain wake services, etc.
if (self.info.satellite is not None) and self.info.satellite.installed:
return self.info.satellite.name
# ASR = automated speech recognition (speech-to-text) # ASR = automated speech recognition (speech-to-text)
asr_installed = [asr for asr in self.info.asr if asr.installed] asr_installed = [asr for asr in self.info.asr if asr.installed]
if asr_installed: if asr_installed:
@ -58,15 +70,6 @@ class WyomingService:
if wake_installed: if wake_installed:
return wake_installed[0].name return wake_installed[0].name
# satellite
satellite_installed: Satellite | None = None
if (self.info.satellite is not None) and self.info.satellite.installed:
satellite_installed = self.info.satellite
if satellite_installed:
return satellite_installed.name
return None return None
@classmethod @classmethod

View File

@ -1,9 +1,11 @@
"""Test tts.""" """Test tts."""
from __future__ import annotations from __future__ import annotations
from unittest.mock import patch from unittest.mock import patch
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from wyoming.info import Info
from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -27,10 +29,13 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None:
"""Test loading info and error raising.""" """Test loading info and error raising."""
mock_client = MockAsyncTcpClient([STT_INFO.event()]) mock_client = MockAsyncTcpClient([STT_INFO.event()])
with patch( with (
"homeassistant.components.wyoming.data.AsyncTcpClient", patch(
mock_client, "homeassistant.components.wyoming.data.AsyncTcpClient",
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): mock_client,
),
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
):
info = await load_wyoming_info( info = await load_wyoming_info(
"localhost", "localhost",
1234, 1234,
@ -75,3 +80,21 @@ async def test_service_name(hass: HomeAssistant) -> None:
service = await WyomingService.create("localhost", 1234) service = await WyomingService.create("localhost", 1234)
assert service is not None assert service is not None
assert service.get_name() == SATELLITE_INFO.satellite.name assert service.get_name() == SATELLITE_INFO.satellite.name
async def test_satellite_with_wake_word(hass: HomeAssistant) -> None:
"""Test that wake word info with satellite doesn't overwrite the service name."""
# Info for local wake word detection
satellite_info = Info(
satellite=SATELLITE_INFO.satellite,
wake=WAKE_WORD_INFO.wake,
)
with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
MockAsyncTcpClient([satellite_info.event()]),
):
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == satellite_info.satellite.name
assert not service.platforms