Use correct service name with Wyoming satellite + local wake word detection (#111870)

* Use correct service name with satellite + local wake word detection

* Don't load platforms for satellite services

* Update homeassistant/components/wyoming/data.py

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>

* Fix ruff error

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Michael Hansen 2024-02-29 12:09:38 -06:00 committed by GitHub
parent 66b17a8e0d
commit f0deae319e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 19 deletions

View File

@ -1,10 +1,11 @@
"""Base class for Wyoming providers."""
from __future__ import annotations
import asyncio
from wyoming.client import AsyncTcpClient
from wyoming.info import Describe, Info, Satellite
from wyoming.info import Describe, Info
from homeassistant.const import Platform
@ -23,14 +24,19 @@ class WyomingService:
self.host = host
self.port = port
self.info = info
platforms = []
self.platforms = []
if (self.info.satellite is not None) and self.info.satellite.installed:
# Don't load platforms for satellite services, such as local wake
# word detection.
return
if any(asr.installed for asr in info.asr):
platforms.append(Platform.STT)
self.platforms.append(Platform.STT)
if any(tts.installed for tts in info.tts):
platforms.append(Platform.TTS)
self.platforms.append(Platform.TTS)
if any(wake.installed for wake in info.wake):
platforms.append(Platform.WAKE_WORD)
self.platforms = platforms
self.platforms.append(Platform.WAKE_WORD)
def has_services(self) -> bool:
"""Return True if services are installed that Home Assistant can use."""
@ -43,6 +49,12 @@ class WyomingService:
def get_name(self) -> str | None:
"""Return name of first installed usable service."""
# Wyoming satellite
# Must be checked first because satellites may contain wake services, etc.
if (self.info.satellite is not None) and self.info.satellite.installed:
return self.info.satellite.name
# ASR = automated speech recognition (speech-to-text)
asr_installed = [asr for asr in self.info.asr if asr.installed]
if asr_installed:
@ -58,15 +70,6 @@ class WyomingService:
if wake_installed:
return wake_installed[0].name
# satellite
satellite_installed: Satellite | None = None
if (self.info.satellite is not None) and self.info.satellite.installed:
satellite_installed = self.info.satellite
if satellite_installed:
return satellite_installed.name
return None
@classmethod

View File

@ -1,9 +1,11 @@
"""Test tts."""
from __future__ import annotations
from unittest.mock import patch
from syrupy.assertion import SnapshotAssertion
from wyoming.info import Info
from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info
from homeassistant.core import HomeAssistant
@ -27,10 +29,13 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None:
"""Test loading info and error raising."""
mock_client = MockAsyncTcpClient([STT_INFO.event()])
with patch(
with (
patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
mock_client,
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
),
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
):
info = await load_wyoming_info(
"localhost",
1234,
@ -75,3 +80,21 @@ async def test_service_name(hass: HomeAssistant) -> None:
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == SATELLITE_INFO.satellite.name
async def test_satellite_with_wake_word(hass: HomeAssistant) -> None:
"""Test that wake word info with satellite doesn't overwrite the service name."""
# Info for local wake word detection
satellite_info = Info(
satellite=SATELLITE_INFO.satellite,
wake=WAKE_WORD_INFO.wake,
)
with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
MockAsyncTcpClient([satellite_info.event()]),
):
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == satellite_info.satellite.name
assert not service.platforms