mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
Use correct service name with Wyoming satellite + local wake word detection (#111870)
* Use correct service name with satellite + local wake word detection * Don't load platforms for satellite services * Update homeassistant/components/wyoming/data.py Co-authored-by: Paulus Schoutsen <balloob@gmail.com> * Fix ruff error --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
66b17a8e0d
commit
f0deae319e
@ -1,10 +1,11 @@
|
|||||||
"""Base class for Wyoming providers."""
|
"""Base class for Wyoming providers."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from wyoming.client import AsyncTcpClient
|
from wyoming.client import AsyncTcpClient
|
||||||
from wyoming.info import Describe, Info, Satellite
|
from wyoming.info import Describe, Info
|
||||||
|
|
||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
|
|
||||||
@ -23,14 +24,19 @@ class WyomingService:
|
|||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.info = info
|
self.info = info
|
||||||
platforms = []
|
self.platforms = []
|
||||||
|
|
||||||
|
if (self.info.satellite is not None) and self.info.satellite.installed:
|
||||||
|
# Don't load platforms for satellite services, such as local wake
|
||||||
|
# word detection.
|
||||||
|
return
|
||||||
|
|
||||||
if any(asr.installed for asr in info.asr):
|
if any(asr.installed for asr in info.asr):
|
||||||
platforms.append(Platform.STT)
|
self.platforms.append(Platform.STT)
|
||||||
if any(tts.installed for tts in info.tts):
|
if any(tts.installed for tts in info.tts):
|
||||||
platforms.append(Platform.TTS)
|
self.platforms.append(Platform.TTS)
|
||||||
if any(wake.installed for wake in info.wake):
|
if any(wake.installed for wake in info.wake):
|
||||||
platforms.append(Platform.WAKE_WORD)
|
self.platforms.append(Platform.WAKE_WORD)
|
||||||
self.platforms = platforms
|
|
||||||
|
|
||||||
def has_services(self) -> bool:
|
def has_services(self) -> bool:
|
||||||
"""Return True if services are installed that Home Assistant can use."""
|
"""Return True if services are installed that Home Assistant can use."""
|
||||||
@ -43,6 +49,12 @@ class WyomingService:
|
|||||||
|
|
||||||
def get_name(self) -> str | None:
|
def get_name(self) -> str | None:
|
||||||
"""Return name of first installed usable service."""
|
"""Return name of first installed usable service."""
|
||||||
|
|
||||||
|
# Wyoming satellite
|
||||||
|
# Must be checked first because satellites may contain wake services, etc.
|
||||||
|
if (self.info.satellite is not None) and self.info.satellite.installed:
|
||||||
|
return self.info.satellite.name
|
||||||
|
|
||||||
# ASR = automated speech recognition (speech-to-text)
|
# ASR = automated speech recognition (speech-to-text)
|
||||||
asr_installed = [asr for asr in self.info.asr if asr.installed]
|
asr_installed = [asr for asr in self.info.asr if asr.installed]
|
||||||
if asr_installed:
|
if asr_installed:
|
||||||
@ -58,15 +70,6 @@ class WyomingService:
|
|||||||
if wake_installed:
|
if wake_installed:
|
||||||
return wake_installed[0].name
|
return wake_installed[0].name
|
||||||
|
|
||||||
# satellite
|
|
||||||
satellite_installed: Satellite | None = None
|
|
||||||
|
|
||||||
if (self.info.satellite is not None) and self.info.satellite.installed:
|
|
||||||
satellite_installed = self.info.satellite
|
|
||||||
|
|
||||||
if satellite_installed:
|
|
||||||
return satellite_installed.name
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -1,9 +1,11 @@
|
|||||||
"""Test tts."""
|
"""Test tts."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
from wyoming.info import Info
|
||||||
|
|
||||||
from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info
|
from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@ -27,10 +29,13 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None:
|
|||||||
"""Test loading info and error raising."""
|
"""Test loading info and error raising."""
|
||||||
mock_client = MockAsyncTcpClient([STT_INFO.event()])
|
mock_client = MockAsyncTcpClient([STT_INFO.event()])
|
||||||
|
|
||||||
with patch(
|
with (
|
||||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
patch(
|
||||||
mock_client,
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
|
mock_client,
|
||||||
|
),
|
||||||
|
patch.object(mock_client, "read_event", side_effect=OSError("Boom!")),
|
||||||
|
):
|
||||||
info = await load_wyoming_info(
|
info = await load_wyoming_info(
|
||||||
"localhost",
|
"localhost",
|
||||||
1234,
|
1234,
|
||||||
@ -75,3 +80,21 @@ async def test_service_name(hass: HomeAssistant) -> None:
|
|||||||
service = await WyomingService.create("localhost", 1234)
|
service = await WyomingService.create("localhost", 1234)
|
||||||
assert service is not None
|
assert service is not None
|
||||||
assert service.get_name() == SATELLITE_INFO.satellite.name
|
assert service.get_name() == SATELLITE_INFO.satellite.name
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_with_wake_word(hass: HomeAssistant) -> None:
|
||||||
|
"""Test that wake word info with satellite doesn't overwrite the service name."""
|
||||||
|
# Info for local wake word detection
|
||||||
|
satellite_info = Info(
|
||||||
|
satellite=SATELLITE_INFO.satellite,
|
||||||
|
wake=WAKE_WORD_INFO.wake,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient([satellite_info.event()]),
|
||||||
|
):
|
||||||
|
service = await WyomingService.create("localhost", 1234)
|
||||||
|
assert service is not None
|
||||||
|
assert service.get_name() == satellite_info.satellite.name
|
||||||
|
assert not service.platforms
|
||||||
|
Loading…
x
Reference in New Issue
Block a user